示例#1
0
def test_execution_simple_fit():
    """
    Smoke test: run a very short one-component fit with a background.

    Don't test for correctness, but check that everything actually executes.
    One spherical component plus a uniform background is synthesised,
    written to disk, and handed to ``expectmax.fit_many_comps`` with a tiny
    burn-in and sampling run. The test passes if nothing raises.
    """
    run_name = 'quickdirty'

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(
        PY_VERS, run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(
        PY_VERS, run_name)

    # Configure logging BEFORE emitting anything: the module-level
    # ``logging.info`` implicitly calls ``basicConfig()`` when the root
    # logger has no handlers, which would turn this call into a no-op and
    # silently discard the log file setup.
    logging.basicConfig(level=logging.INFO,
                        filemode='w',
                        filename=log_filename)

    logging.info(60 * '-')
    logging.info(15 * '-' + '{:^30}'.format('TEST: ' + run_name) + 15 * '-')
    logging.info(60 * '-')

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        # X, Y, Z, U, V, W, dX, dV,  age,
        [0, 0, 0, 0, 0, 0, 10., 5, uniform_age],
    ])
    starcount = 100
    background_density = 1e-9
    ncomps = sphere_comp_pars.shape[0]

    synth_data = SynthData(
        pars=sphere_comp_pars,
        starcounts=[starcount],
        Components=SphereComponent,
        background_density=background_density,
    )
    synth_data.synthesise_everything()

    tabletool.convert_table_astro2cart(synth_data.table)

    # Every row beyond the requested component stars is a background star.
    background_count = len(synth_data.table) - starcount
    logging.info('Generated {} background stars'.format(background_count))

    # Insert background densities (same constant log-density for every row)
    synth_data.table['background_log_overlap'] =\
        len(synth_data.table) * [np.log(background_density)]
    synth_data.table.write(data_filename, overwrite=True)

    # Results are unpacked only to confirm the call returns three values.
    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir, burnin=10, sampling_steps=10,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 use_background=True, ignore_stable_comps=False,
                                 max_em_iterations=200)
示例#2
0
def test_fit_stability_mixed_comps():
    """
    Have a fit with some iterations that have a mix of stable and
    unstable comps.

    Three spherical components are synthesised: one far from the others
    (seeded with its true memberships so it can settle quickly) and two that
    overlap. The fit is run with ``ignore_stable_comps=True`` and the
    recovered memberships and component parameters are compared with the
    inputs after finding the best component permutation.

    TODO: Maybe give 2 similar comps tiny age but overlapping origins
    (e.g. trace a shared current-day mean back by different ages).
    """
    run_name = 'mixed_stability'

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(
        PY_VERS, run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(
        PY_VERS, run_name)

    # Configure logging BEFORE emitting anything: the module-level
    # ``logging.info`` implicitly calls ``basicConfig()`` when the root
    # logger has no handlers, which would turn this call into a no-op.
    logging.basicConfig(level=logging.INFO,
                        filemode='w',
                        filename=log_filename)

    logging.info(60 * '-')
    logging.info(15 * '-' + '{:^30}'.format('TEST: ' + run_name) + 15 * '-')
    logging.info(60 * '-')

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        #   X,  Y,  Z, U, V, W, dX, dV,  age,
        [50, 0, 0, 0, 50, 0, 10., 5,
         uniform_age],  # Very distant (and stable) comp
        [0, -20, 0, 0, -5, 0, 10., 5, uniform_age],  # Overlapping comp 1
        [0, 20, 0, 0, 5, 0, 10., 5, uniform_age],  # Overlapping comp 2
    ])
    starcounts = [50, 100, 200]
    nstars = np.sum(starcounts)
    ncomps = sphere_comp_pars.shape[0]

    # True memberships: stars are laid out component by component, so each
    # component owns one contiguous run of rows.
    true_memb_probs = np.zeros((nstars, ncomps))
    start = 0
    for i in range(ncomps):
        true_memb_probs[start:start + starcounts[i], i] = 1.0
        start += starcounts[i]

    # Initialise some random membership probabilities
    #  which will serve as our starting guess
    init_memb_probs = np.random.rand(nstars, ncomps)
    # To aid a component in quickly becoming stable, initialise the
    # memberships correctly for stars belonging to this component
    init_memb_probs[:starcounts[0]] = 0.
    init_memb_probs[:starcounts[0], 0] = 1.
    init_memb_probs[starcounts[0]:, 0] = 0.

    # Normalising such that each row sums to 1
    init_memb_probs = init_memb_probs / init_memb_probs.sum(axis=1,
                                                            keepdims=True)

    synth_data = SynthData(
        pars=sphere_comp_pars,
        starcounts=starcounts,
        Components=SphereComponent,
    )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table,
                                       write_table=True,
                                       filename=data_filename)

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]
    SphereComponent.store_raw_components(savedir + 'origins.npy', origins)

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir, init_memb_probs=init_memb_probs,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 ignore_stable_comps=True)

    # The fit may return components in a different order to the inputs.
    perm = expectmax.get_best_permutation(memb_probs, true_memb_probs)

    logging.info('Best permutation is: {}'.format(perm))

    # Calculate the membership difference, we divide by 2 since
    # incorrectly allocated stars are double counted
    total_diff = 0.5 * np.sum(np.abs(true_memb_probs - memb_probs[:, perm]))

    # Assert that the total membership mismatch is below 10% of all stars
    assert total_diff < 0.1 * nstars

    for origin, best_comp in zip(origins, np.array(best_comps)[perm]):
        assert (isinstance(origin, SphereComponent)
                and isinstance(best_comp, SphereComponent))
        o_pars = origin.get_pars()
        b_pars = best_comp.get_pars()

        logging.info("origin pars:   {}".format(o_pars))
        logging.info("best fit pars: {}".format(b_pars))
        assert np.allclose(origin.get_mean(), best_comp.get_mean(), atol=5.)
        assert np.allclose(origin.get_sphere_dx(),
                           best_comp.get_sphere_dx(),
                           atol=2.)
        assert np.allclose(origin.get_sphere_dv(),
                           best_comp.get_sphere_dv(),
                           atol=2.)
        assert np.allclose(origin.get_age(), best_comp.get_age(), atol=1.)
示例#3
0
def test_fit_many_comps():
    """
    Synthesise a file with negligible error, retrieve initial
    parameters

    Two well-separated spherical components are synthesised, the fit is
    started from random (row-normalised) membership probabilities, and the
    recovered memberships and component parameters are compared with the
    inputs after finding the best component permutation.

    Takes a while... maybe this belongs in integration unit_tests
    """
    run_name = 'stationary'

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(
        PY_VERS, run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(
        PY_VERS, run_name)

    # Configure logging BEFORE emitting anything: the module-level
    # ``logging.info`` implicitly calls ``basicConfig()`` when the root
    # logger has no handlers, which would turn this call into a no-op.
    logging.basicConfig(level=logging.INFO,
                        filemode='w',
                        filename=log_filename)

    logging.info(60 * '-')
    logging.info(15 * '-' + '{:^30}'.format('TEST: ' + run_name) + 15 * '-')
    logging.info(60 * '-')

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        #   X,  Y,  Z, U, V, W, dX, dV,  age,
        [-50, -50, -50, 0, 0, 0, 10., 5, uniform_age],
        [50, 50, 50, 0, 0, 0, 10., 5, uniform_age],
    ])
    starcounts = [20, 50]
    nstars = np.sum(starcounts)
    ncomps = sphere_comp_pars.shape[0]

    # True memberships: stars are laid out component by component, so each
    # component owns one contiguous run of rows.
    true_memb_probs = np.zeros((nstars, ncomps))
    start = 0
    for i in range(ncomps):
        true_memb_probs[start:start + starcounts[i], i] = 1.0
        start += starcounts[i]

    # Initialise some random membership probabilities
    # Normalising such that each row sums to 1
    init_memb_probs = np.random.rand(nstars, ncomps)
    init_memb_probs = init_memb_probs / init_memb_probs.sum(axis=1,
                                                            keepdims=True)

    synth_data = SynthData(
        pars=sphere_comp_pars,
        starcounts=starcounts,
        Components=SphereComponent,
    )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table,
                                       write_table=True,
                                       filename=data_filename)

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir, init_memb_probs=init_memb_probs,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 ignore_stable_comps=False)

    # The fit may return components in a different order to the inputs.
    perm = expectmax.get_best_permutation(memb_probs, true_memb_probs)

    logging.info('Best permutation is: {}'.format(perm))

    assert np.allclose(true_memb_probs, memb_probs[:, perm])

    for origin, best_comp in zip(origins, np.array(best_comps)[perm]):
        assert (isinstance(origin, SphereComponent)
                and isinstance(best_comp, SphereComponent))
        o_pars = origin.get_pars()
        b_pars = best_comp.get_pars()

        logging.info("origin pars:   {}".format(o_pars))
        logging.info("best fit pars: {}".format(b_pars))
        assert np.allclose(origin.get_mean(), best_comp.get_mean(), atol=5.)
        assert np.allclose(origin.get_sphere_dx(),
                           best_comp.get_sphere_dx(),
                           atol=2.)
        assert np.allclose(origin.get_sphere_dv(),
                           best_comp.get_sphere_dv(),
                           atol=2.)
        assert np.allclose(origin.get_age(), best_comp.get_age(), atol=1.)
示例#4
0
def test_fit_one_comp_with_background():
    """
    Synthesise a file with negligible error, retrieve initial
    parameters

    One spherical component is synthesised together with a uniform
    background. After fitting with ``use_background=True`` the test checks
    that the component parameters are recovered, that most component stars
    are classified as members, that few background stars are, and that the
    mean reported membership probability agrees with the recovery rate.

    Takes a while...
    """
    run_name = 'background'

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(
        PY_VERS, run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(
        PY_VERS, run_name)

    # Configure logging BEFORE emitting anything: the module-level
    # ``logging.info`` implicitly calls ``basicConfig()`` when the root
    # logger has no handlers, which would turn this call into a no-op.
    logging.basicConfig(level=logging.INFO,
                        filemode='w',
                        filename=log_filename)

    logging.info(60 * '-')
    logging.info(15 * '-' + '{:^30}'.format('TEST: ' + run_name) + 15 * '-')
    logging.info(60 * '-')

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        # X, Y, Z, U, V, W, dX, dV,  age,
        [0, 0, 0, 0, 0, 0, 10., 5, uniform_age],
    ])
    starcount = 200
    background_density = 1e-9
    ncomps = sphere_comp_pars.shape[0]

    synth_data = SynthData(
        pars=sphere_comp_pars,
        starcounts=[starcount],
        Components=SphereComponent,
        background_density=background_density,
    )
    synth_data.synthesise_everything()

    tabletool.convert_table_astro2cart(synth_data.table)

    # Every row beyond the requested component stars is a background star.
    background_count = len(synth_data.table) - starcount
    logging.info('Generated {} background stars'.format(background_count))

    # Insert background densities (same constant log-density for every row)
    synth_data.table['background_log_overlap'] =\
        len(synth_data.table) * [np.log(background_density)]
    synth_data.table.write(data_filename, overwrite=True)

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir, burnin=500, sampling_steps=5000,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 use_background=True, ignore_stable_comps=False,
                                 max_em_iterations=200)

    # Check parameters are close
    assert np.allclose(sphere_comp_pars, best_comps[0].get_pars(), atol=1.5)

    # Check most assoc members are correctly classified
    # (the first `starcount` rows are the component stars)
    recovery_count_threshold = 0.95 * starcount
    recovery_count_actual = np.sum(memb_probs[:starcount, 0] > 0.5)
    assert recovery_count_threshold < recovery_count_actual

    # Check most background stars are correctly classified
    # Number of bg stars classified as members should be less than 5%
    # of all background stars
    contamination_count_threshold = 0.05 * len(memb_probs[starcount:])
    contamination_count_actual = np.sum(memb_probs[starcount:, 0] > 0.5)
    assert contamination_count_threshold > contamination_count_actual

    # Check reported membership probabilities are consistent with recovery
    # rate (within 5%). Divide by a float so the ratio is not truncated by
    # integer division under Python 2.
    mean_membership_confidence = np.mean(memb_probs[:starcount, 0])
    assert np.isclose(recovery_count_actual / float(starcount),
                      mean_membership_confidence,
                      atol=0.05)
示例#5
0
    # handle special case of one component
    if ncomps == 1:
        logging.info("******************************************")
        logging.info("*********  FITTING 1 COMPONENT  **********")
        logging.info("******************************************")
        run_dir = rdir + '{}/'.format(ncomps)
        mkpath(run_dir)

        try:
            new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
            new_meds = np.load(run_dir + 'final/final_med_errs.npy')
            new_z = np.load(run_dir + 'final/final_membership.npy')
            logging.info("Loaded from previous run")
        except IOError:
            new_groups, new_meds, new_z = \
                em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                                  bg_ln_ols=bg_ln_ols) # kill component here
            new_groups = np.array(new_groups)

        new_lnlike = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols)
        new_lnpost = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
        new_BIC = em.calc_bic(star_pars, ncomps, new_lnlike)
    # handle multiple components
    else:
        logging.info("******************************************")
        logging.info("*********  FITTING {} COMPONENTS  *********".\
                     format(ncomps))
        logging.info("******************************************")
        best_fits = []
all_xyzuvw_init, origins = syn.synthesiseManyXYZUVW(
    group_pars, sphere=True, return_groups=True
)
logging.info(" done")
    # xyzuvw_init, group =\
    #     syn.synthesiseXYZUVW(group_pars, sphere=True, return_group=True)
    # all_xyzuvw_init = np.vstack((all_xyzuvw_init, xyzuvw_init))
    #
    # xyzuvw_now_perf = torb.traceManyOrbitXYZUVW(xyzuvw_init, group.age,
    #                                             single_age=True)
    # all_xyzuvw_now_perf =\
    #     np.vstack((all_xyzuvw_now_perf, xyzuvw_now_perf))
    # origins.append(group)

logging.info("Saving synthetic data...")
np.save(rdir+groups_savefile, origins)
np.save(rdir+xyzuvw_perf_file, all_xyzuvw_now_perf)
astro_table = chronostar.synthdata.measureXYZUVW(all_xyzuvw_now_perf, 1.0,
                                                 savefile=rdir+astro_savefile)

star_pars = cv.convertMeasurementsToCartesian(
    astro_table, savefile=rdir+xyzuvw_conv_savefile,
)
em.fit_many_comps(star_pars, ngroups, origins=origins,
                  rdir=rdir, pool=pool,
                  #init_with_origin=True
                  )

if using_mpi:
    pool.close()
示例#7
0
    if ncomps == 1:
        logging.info("******************************************")
        logging.info("*********  FITTING 1 COMPONENT  **********")
        logging.info("******************************************")
        run_dir = rdir + '{}/'.format(ncomps)
        mkpath(run_dir)

        try:
            new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
            new_meds = np.load(run_dir + 'final/final_med_errs.npy')
            new_z = np.load(run_dir + 'final/final_membership.npy')
            logging.info("Loaded from previous run")
        except IOError:
            new_groups, new_meds, new_z =\
                em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                                  bg_dens=BG_DENS,
                                  )
            new_groups = np.array(new_groups)

        new_lnlike = em.get_overall_lnlikelihood(star_pars,
                                                 new_groups,
                                                 bg_ln_ols=bg_ln_ols)
        new_lnpost = em.get_overall_lnlikelihood(star_pars,
                                                 new_groups,
                                                 bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
        new_BIC = em.calc_bic(star_pars, ncomps, new_lnlike)
    # handle multiple components
    else:
        logging.info("******************************************")
        logging.info("*********  FITTING {} COMPONENTS  *********".\
            final_cdir = run_dir + 'final/comp{}/'.format(i)
            chain = np.load(final_cdir + 'final_chain.npy')
            lnprob = np.load(final_cdir + 'final_lnprob.npy')
            npars = len(Component.PARAMETER_FORMAT)
            best_ix = np.argmax(lnprob)
            best_pars = chain.reshape(-1, npars)[best_ix]
            prev_comps[i] = Component(emcee_pars=best_pars)
        np.save(str(run_dir + 'final/' + final_comps_file), prev_comps)

    logging.info('Loaded from previous run')
except IOError:
    prev_comps, prev_med_and_spans, prev_memb_probs = \
        expectmax.fit_many_comps(data=data_dict, ncomps=ncomps, rdir=run_dir,
                                 trace_orbit_func=trace_orbit_func,
                                 burnin=config.advanced['burnin_steps'],
                                 sampling_steps=config.advanced['sampling_steps'],
                                 use_background=config.config[
                                    'include_background_distribution'],
                                 init_memb_probs=init_memb_probs,
                                 )

# Calculate global score of fit for comparison with future fits with different
# component counts
prev_lnlike = expectmax.get_overall_lnlikelihood(
    data_dict,
    prev_comps,
    # bg_ln_ols=bg_ln_ols,
)
prev_lnpost = expectmax.get_overall_lnlikelihood(
    data_dict,
    prev_comps,
    # bg_ln_ols=bg_ln_ols,
示例#9
0
def test_fit_one_comp_with_background():
    """
    Synthesise a file with negligible error, retrieve initial
    parameters

    One spherical component is synthesised together with a uniform
    background. After fitting with ``use_background=True`` the test checks
    that the component parameters are recovered, that most component stars
    are classified as members, that few background stars are, and that the
    mean reported membership probability agrees with the recovery rate.

    Takes a while... maybe this belongs in integration unit_tests

    Returns
    -------
    best_comps, med_and_spans, memb_probs
        The three values returned by ``expectmax.fit_many_comps``, passed
        through unchanged for callers that want to inspect the fit.
    """
    run_name = 'background'
    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(PY_VERS,
                                                                 run_name)
    # This assignment used to be commented out while `log_filename` was
    # still passed to basicConfig below, which raised NameError.
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(PY_VERS,
                                                              run_name)

    logging.basicConfig(level=logging.INFO, filemode='w',
                        filename=log_filename)
    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        # X, Y, Z, U, V, W, dX, dV,  age,
        [ 0, 0, 0, 0, 0, 0, 10.,  5, uniform_age],
    ])
    starcount = 100

    background_density = 1e-9

    ncomps = sphere_comp_pars.shape[0]

    synth_data = SynthData(pars=sphere_comp_pars, starcounts=[starcount],
                           Components=SphereComponent,
                           background_density=background_density,
                           )
    synth_data.synthesise_everything()

    tabletool.convert_table_astro2cart(synth_data.table)

    # Every row beyond the requested component stars is a background star.
    background_count = len(synth_data.table) - starcount
    logging.info('Generated {} background stars'.format(background_count))

    # Insert background densities, THEN write the table, so the stored data
    # file also carries the background column.
    synth_data.table['background_log_overlap'] =\
        len(synth_data.table) * [np.log(background_density)]
    synth_data.table.write(data_filename, overwrite=True)

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table,
                                 ncomps=ncomps,
                                 rdir=savedir,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 use_background=True)

    # Check parameters are close
    assert np.allclose(sphere_comp_pars, best_comps[0].get_pars(),
                       atol=1.)

    # Check most assoc members are correctly classified
    # (the first `starcount` rows are the component stars)
    recovery_count_threshold = 0.95 * starcount
    recovery_count_actual = np.sum(np.round(memb_probs[:starcount, 0]))
    assert recovery_count_threshold < recovery_count_actual

    # Check most background stars are correctly classified: fewer than 5%
    # of them may be (wrongly) assigned to the component
    contamination_count_threshold = 0.05 * len(memb_probs[starcount:])
    contamination_count_actual = np.sum(np.round(memb_probs[starcount:, 0]))
    assert contamination_count_threshold > contamination_count_actual

    # Check reported membership probabilities are consistent with recovery
    # rate (within 5%)
    mean_membership_confidence = np.mean(memb_probs[:starcount, 0])
    assert np.isclose(recovery_count_actual / float(starcount),
                      mean_membership_confidence,
                      atol=0.05)

    # Returned last so the assertions above are actually executed.
    return best_comps, med_and_spans, memb_probs
示例#10
0
def test_fit_many_comps():
    """
    Synthesise a file with negligible error, retrieve initial
    parameters

    Two well-separated spherical components are synthesised and fitted.
    The recovered memberships and component parameters are compared with
    the inputs, allowing for the fit returning the two components in
    reversed order.

    Takes a while... maybe this belongs in integration unit_tests
    """
    run_name = 'stationary'
    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(PY_VERS,
                                                                 run_name)
    # This assignment used to be commented out while `log_filename` was
    # still passed to basicConfig below, which raised NameError.
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(PY_VERS,
                                                              run_name)

    logging.basicConfig(level=logging.INFO, filemode='w',
                        filename=log_filename)
    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        #  X,  Y,  Z, U, V, W, dX, dV,  age,
        [-50,-50,-50, 0, 0, 0, 10.,  5, uniform_age],
        [ 50, 50, 50, 0, 0, 0, 10.,  5, uniform_age],
    ])
    starcounts = [200,200]
    ncomps = sphere_comp_pars.shape[0]

    # True memberships: stars are laid out component by component, so each
    # component owns one contiguous run of rows. Derived from `starcounts`
    # rather than hard-coded so the two cannot drift apart.
    true_memb_probs = np.zeros((np.sum(starcounts), ncomps))
    start = 0
    for i in range(ncomps):
        true_memb_probs[start:start + starcounts[i], i] = 1.0
        start += starcounts[i]

    synth_data = SynthData(pars=sphere_comp_pars, starcounts=starcounts,
                           Components=SphereComponent,
                           )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table,
                                       write_table=True,
                                       filename=data_filename)

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table,
                                 ncomps=ncomps,
                                 rdir=savedir,
                                 trace_orbit_func=dummy_trace_orbit_func, )

    # compare fit with input
    if not np.allclose(true_memb_probs, memb_probs):
        # If not close, check if flipping component order fixes things
        memb_probs = memb_probs[:,::-1]
        best_comps = best_comps[::-1]
    assert np.allclose(true_memb_probs, memb_probs)

    for origin, best_comp in zip(origins, best_comps):
        assert (isinstance(origin, SphereComponent) and
                isinstance(best_comp, SphereComponent))
        o_pars = origin.get_pars()
        b_pars = best_comp.get_pars()

        logging.info("origin pars:   {}".format(o_pars))
        logging.info("best fit pars: {}".format(b_pars))
        assert np.allclose(origin.get_mean(),
                           best_comp.get_mean(),
                           atol=5.)
        assert np.allclose(origin.get_sphere_dx(),
                           best_comp.get_sphere_dx(),
                           atol=2.)
        assert np.allclose(origin.get_sphere_dv(),
                           best_comp.get_sphere_dv(),
                           atol=2.)
        assert np.allclose(origin.get_age(),
                           best_comp.get_age(),
                           atol=1.)
    bp_cov = np.cov(star_means.T)
    bp_dx = np.sqrt(np.min([bp_cov[0,0], bp_cov[1,1], bp_cov[2,2]]))
    bp_dv = np.sqrt(np.min([bp_cov[3,3], bp_cov[4,4], bp_cov[5,5]]))
    bp_age = 0.5
    bp_pars = np.hstack((bp_mean, bp_dx, bp_dv, bp_age, nstars))
    bp_group = chronostar.component.Component(bp_pars)
    origins = [bp_group]

#go through and compare overlap with groups with
#background overlap
#
#bp_pars = np.array([
#   2.98170398e+01,  4.43573995e+01,  2.29251498e+01, -9.65731744e-01,
#   -3.42827894e+00, -3.99928052e-02 , 2.63084094e+00,  1.05302890e-01,
#   1.59367119e+01, nstars
#])
#bp_group = syn.Group(bp_pars)
#
# --------------------------------------------------------------------------
# Run fit
# --------------------------------------------------------------------------
logging.info("Using data file {}".format(xyzuvw_file))
logging.info("Everything loaded, about to fit with {} components"\
    .format(NGROUPS))
em.fit_many_comps(star_pars, NGROUPS,
                  rdir=rdir, pool=pool, offset=True, bg_hist_file=bg_hist_file,
                  origins=origins, init_with_origin=init_origin,
                  )
if using_mpi:
    pool.close()
示例#12
0
            best_ix = np.argmax(lnprob)
            best_pars = chain.reshape(-1, npars)[best_ix]
            prev_comps[i] = Component(emcee_pars=best_pars)
        Component.store_raw_components(
            str(run_dir + 'final/' + final_comps_file), prev_comps)
        # np.save(str(run_dir+'final/'+final_comps_file), prev_comps)

    logging.info('Loaded from previous run')
except IOError:
    prev_comps, prev_med_and_spans, prev_memb_probs = \
        expectmax.fit_many_comps(data=data_dict, ncomps=ncomps, rdir=run_dir,
                                 trace_orbit_func=trace_orbit_func,
                                 burnin=config.advanced['burnin_steps'],
                                 sampling_steps=config.advanced['sampling_steps'],
                                 use_background=config.config[
                                    'include_background_distribution'],
                                 init_memb_probs=init_memb_probs,
                                 Component=Component,
                                 store_burnin_chains=store_burnin_chains,
                                 max_iters=MAX_ITERS,
                                 )

# Calculate global score of fit for comparison with future fits with different
# component counts
prev_lnlike = expectmax.get_overall_lnlikelihood(
    data_dict,
    prev_comps,
    # bg_ln_ols=bg_ln_ols,
)
prev_lnpost = expectmax.get_overall_lnlikelihood(
    data_dict,
    # handle special case of one component
    if ncomps == 1:
        logging.info("******************************************")
        logging.info("*********  FITTING 1 COMPONENT  **********")
        logging.info("******************************************")
        run_dir = rdir + '{}/'.format(ncomps)
        mkpath(run_dir)

        try:
            new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
            new_meds = np.load(run_dir + 'final/final_med_errs.npy')
            new_z = np.load(run_dir + 'final/final_membership.npy')
            logging.info("Loaded from previous run")
        except IOError:
            new_groups, new_meds, new_z = \
                em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                                  bg_ln_ols=bg_ln_ols)
            new_groups = np.array(new_groups)

        new_lnlike = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols)
        new_lnpost = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
        new_BIC = em.calc_bic(star_pars, ncomps, new_lnlike)
    # handle multiple components
    else:
        logging.info("******************************************")
        logging.info("*********  FITTING {} COMPONENTS  *********".\
                     format(ncomps))
        logging.info("******************************************")
        best_fits = []
            final_cdir = run_dir + 'final/comp{}/'.format(i)
            chain = np.load(final_cdir + 'final_chain.npy')
            lnprob = np.load(final_cdir + 'final_lnprob.npy')
            npars = len(Component.PARAMETER_FORMAT)
            best_ix = np.argmax(lnprob)
            best_pars = chain.reshape(-1,npars)[best_ix]
            prev_comps[i] = Component(emcee_pars=best_pars)
        np.save(str(run_dir+'final/'+final_comps_file), prev_comps)

    logging.info('Loaded from previous run')
except IOError:
    prev_comps, prev_med_and_spans, prev_memb_probs = \
        expectmax.fit_many_comps(data=data_dict, ncomps=ncomps, rdir=run_dir,
                                 trace_orbit_func=trace_orbit_func,
                                 burnin=config.advanced['burnin_steps'],
                                 sampling_steps=config.advanced['sampling_steps'],
                                 use_background=config.config[
                                    'include_background_distribution'],
                                 init_memb_probs=init_memb_probs,
                                 )


# Calculate global score of fit for comparison with future fits with different
# component counts
prev_lnlike = expectmax.get_overall_lnlikelihood(data_dict, prev_comps,
                                                 # bg_ln_ols=bg_ln_ols,
                                                 )
prev_lnpost = expectmax.get_overall_lnlikelihood(data_dict, prev_comps,
                                                 # bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
prev_bic = expectmax.calc_bic(data_dict, ncomps, prev_lnlike)
init_z=init_z.T

print(np.sum(init_z, axis=0))

ncomps=len(np.sum(init_z, axis=0))-1
print('ncomps: %d'%ncomps)

print('%d components'%(len(np.sum(init_z, axis=0))-1))
print('init_z successful!! Yey')
# --------------------------------------------------------------------------
# Perform one EM fit
# --------------------------------------------------------------------------
logging.info("Using data file {}".format(xyzuvw_file))
logging.info("Everything loaded, about to fit with {} components"\
    .format(ncomps))
final_groups, final_med_errs, z = em.fit_many_comps(star_pars, ncomps, bg_ln_ols=bg_ln_ols,
                                                    rdir=rdir, pool=pool, init_memb_probs=init_z, ignore_dead_comps=True)


# --------------------------------------------------------------------------
# Repeat EM fit but killing components with too few members
# --------------------------------------------------------------------------
# e.g. maybe < 3 expected star count
# take results from entire past EM fit


#
# init_origin = False
# origins = None
# if NGROUPS == 1:
#     init_origin = True
#     nstars = star_means.shape[0]
    if ncomps == 1:
        logging.info("******************************************")
        logging.info("*********  FITTING 1 COMPONENT  **********")
        logging.info("******************************************")
        run_dir = rdir + '{}/'.format(ncomps)
        mkpath(run_dir)

        try:
            new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
            new_meds = np.load(run_dir + 'final/final_med_errs.npy')
            new_z = np.load(run_dir + 'final/final_membership.npy')
            logging.info("Loaded from previous run")
        except IOError:
            new_groups, new_meds, new_z =\
                em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                                  bg_dens=BG_DENS,
                                  )
            new_groups = np.array(new_groups)

        new_lnlike = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols)
        new_lnpost = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
        new_BIC = em.calc_bic(star_pars, ncomps, new_lnlike)
    # handle multiple components
    else:
        logging.info("******************************************")
        logging.info("*********  FITTING {} COMPONENTS  *********".\
                     format(ncomps))
        logging.info("******************************************")
#
# xyzuvw_now_perf = torb.traceManyOrbitXYZUVW(xyzuvw_init, group.age,
#                                             single_age=True)
# all_xyzuvw_now_perf =\
#     np.vstack((all_xyzuvw_now_perf, xyzuvw_now_perf))
# origins.append(group)

logging.info("Saving synthetic data...")
np.save(rdir + groups_savefile, origins)
np.save(rdir + xyzuvw_perf_file, all_xyzuvw_now_perf)
astro_table = chronostar.synthdata.measureXYZUVW(all_xyzuvw_now_perf,
                                                 1.0,
                                                 savefile=rdir +
                                                 astro_savefile)

star_pars = cv.convertMeasurementsToCartesian(
    astro_table,
    savefile=rdir + xyzuvw_conv_savefile,
)
em.fit_many_comps(
    star_pars,
    ngroups,
    origins=origins,
    rdir=rdir,
    pool=pool,
    #init_with_origin=True
)

if using_mpi:
    pool.close()