def test_execution_simple_fit():
    """Don't test for correctness, but check that everything actually executes"""
    run_name = 'quickdirty'
    logging.info(60 * '-')
    logging.info(15 * '-' + '{:^30}'.format('TEST: ' + run_name) + 15 * '-')
    logging.info(60 * '-')

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(
            PY_VERS, run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(
            PY_VERS, run_name)
    logging.basicConfig(level=logging.INFO, filemode='w',
                        filename=log_filename)

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        # X, Y, Z, U, V, W, dX, dV, age,
        [ 0,  0,  0, 0, 0, 0, 10.,  5, uniform_age],
    ])
    starcount = 100
    background_density = 1e-9

    ncomps = sphere_comp_pars.shape[0]

    # true_memb_probs = np.zeros((starcount, ncomps))
    # true_memb_probs[:, 0] = 1.

    synth_data = SynthData(pars=sphere_comp_pars, starcounts=[starcount],
                           Components=SphereComponent,
                           background_density=background_density,
                           )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table)
    background_count = len(synth_data.table) - starcount

    # insert background densities
    synth_data.table['background_log_overlap'] = \
        len(synth_data.table) * [np.log(background_density)]
    synth_data.table.write(data_filename, overwrite=True)

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir, burnin=10, sampling_steps=10,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 use_background=True,
                                 ignore_stable_comps=False,
                                 max_em_iterations=200)
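# Note: the tests in this section rely on module-level fixtures defined
# elsewhere in the test file. The sketch below shows plausible definitions,
# assuming the usual chronostar module layout; in particular,
# dummy_trace_orbit_func is assumed to be a no-op stand-in for the real
# orbit integrator (valid here because every synthetic component is given
# a negligible age).
import sys
import logging
import numpy as np
from distutils.dir_util import mkpath

from chronostar.synthdata import SynthData
from chronostar.component import SphereComponent
from chronostar import expectmax, tabletool

# Tag output directories with the running Python major version, e.g. 'py3'
PY_VERS = 'py{}'.format(sys.version_info[0])


def dummy_trace_orbit_func(loc, times=None):
    """A sketch: skip orbit integration entirely and return the input
    phase-space coordinates unchanged."""
    return loc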
def test_fit_stability_mixed_comps():
    """
    Have a fit with some iterations that have a mix of stable and
    unstable comps.

    TODO: Maybe give 2 similar comps tiny age but overlapping origins
    """
    run_name = 'mixed_stability'
    logging.info(60 * '-')
    logging.info(15 * '-' + '{:^30}'.format('TEST: ' + run_name) + 15 * '-')
    logging.info(60 * '-')

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(
            PY_VERS, run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(
            PY_VERS, run_name)
    logging.basicConfig(level=logging.INFO, filemode='w',
                        filename=log_filename)

    shared_cd_mean = np.zeros(6)
    tiny_age = 0.1
    medium_age = 10.

    # origin_1 = traceorbit.trace_cartesian_orbit(shared_cd_mean,
    #                                             times=-medium_age)
    # origin_2 = traceorbit.trace_cartesian_orbit(shared_cd_mean,
    #                                             times=-2 * medium_age)
    #
    # cd_mean_3 = np.array([-200, 200, 0, 0, 50, 0.])
    # origin_3 = traceorbit.trace_cartesian_orbit(cd_mean_3, times=-tiny_age)
    #
    # sphere_comp_pars = np.array([
    #     # X, Y, Z, U, V, W, dX, dV, age,
    #     # Next two comps share a current-day origin, so hopefully will
    #     # need several iterations to disentangle
    #     np.hstack((origin_1, 10., 5., medium_age)),
    #     np.hstack((origin_2, 10., 5., 2 * medium_age)),
    #     # A distinct comp that is stable quickly
    #     np.hstack((origin_3, 10., 5., tiny_age)),
    # ])
    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        #  X,   Y, Z, U,  V, W, dX, dV, age,
        [ 50,   0, 0, 0, 50, 0, 10.,  5, uniform_age],  # Very distant (and stable) comp
        [  0, -20, 0, 0, -5, 0, 10.,  5, uniform_age],  # Overlapping comp 1
        [  0,  20, 0, 0,  5, 0, 10.,  5, uniform_age],  # Overlapping comp 2
    ])
    starcounts = [50, 100, 200]
    ncomps = sphere_comp_pars.shape[0]

    # initialise z (the true membership probabilities) appropriately
    true_memb_probs = np.zeros((np.sum(starcounts), ncomps))
    start = 0
    for i in range(ncomps):
        true_memb_probs[start:start + starcounts[i], i] = 1.0
        start += starcounts[i]

    # Initialise some random membership probabilities
    # which will serve as our starting guess
    init_memb_probs = np.random.rand(np.sum(starcounts), ncomps)
    # To aid a component in quickly becoming stable, initialise the
    # memberships correctly for stars belonging to this component
    init_memb_probs[:starcounts[0]] = 0.
    init_memb_probs[:starcounts[0], 0] = 1.
    init_memb_probs[starcounts[0]:, 0] = 0.
    # Normalise such that each row sums to 1
    init_memb_probs = (init_memb_probs.T / init_memb_probs.sum(axis=1)).T

    synth_data = SynthData(pars=sphere_comp_pars, starcounts=starcounts,
                           Components=SphereComponent,
                           )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table,
                                       write_table=True,
                                       filename=data_filename)

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]
    SphereComponent.store_raw_components(savedir + 'origins.npy', origins)

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir,
                                 init_memb_probs=init_memb_probs,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 ignore_stable_comps=True)

    perm = expectmax.get_best_permutation(memb_probs, true_memb_probs)
    logging.info('Best permutation is: {}'.format(perm))

    # Calculate the membership discrepancy; divide by 2 since each
    # incorrectly allocated star is double counted
    total_diff = 0.5 * np.sum(np.abs(true_memb_probs - memb_probs[:, perm]))

    # Assert that the membership discrepancy is less than 10% of the
    # total star count
    assert total_diff < 0.1 * np.sum(starcounts)

    for origin, best_comp in zip(origins, np.array(best_comps)[perm, ]):
        assert (isinstance(origin, SphereComponent) and
                isinstance(best_comp, SphereComponent))
        o_pars = origin.get_pars()
        b_pars = best_comp.get_pars()

        logging.info("origin pars:   {}".format(o_pars))
        logging.info("best fit pars: {}".format(b_pars))
        assert np.allclose(origin.get_mean(), best_comp.get_mean(), atol=5.)
        assert np.allclose(origin.get_sphere_dx(), best_comp.get_sphere_dx(),
                           atol=2.)
        assert np.allclose(origin.get_sphere_dv(), best_comp.get_sphere_dv(),
                           atol=2.)
        assert np.allclose(origin.get_age(), best_comp.get_age(), atol=1.)
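# expectmax.get_best_permutation is used above to undo label switching:
# component order is arbitrary in a mixture fit, so fitted columns must be
# matched to true columns before comparing memberships. Its definition is
# not shown in this section; a minimal brute-force sketch of the idea,
# feasible here because ncomps is small:
from itertools import permutations

import numpy as np


def get_best_permutation_sketch(memb_probs, true_memb_probs):
    """Return the column ordering of memb_probs that minimises the total
    membership discrepancy with true_memb_probs."""
    ncomps = true_memb_probs.shape[1]
    best_perm, best_diff = None, np.inf
    for perm in permutations(range(ncomps)):
        diff = np.sum(np.abs(true_memb_probs - memb_probs[:, list(perm)]))
        if diff < best_diff:
            best_perm, best_diff = list(perm), diff
    return best_perm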
def test_fit_many_comps():
    """
    Synthesise a file with negligible error, retrieve initial parameters

    Takes a while... maybe this belongs in integration unit_tests
    """
    run_name = 'stationary'
    logging.info(60 * '-')
    logging.info(15 * '-' + '{:^30}'.format('TEST: ' + run_name) + 15 * '-')
    logging.info(60 * '-')

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(
            PY_VERS, run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(
            PY_VERS, run_name)
    logging.basicConfig(level=logging.INFO, filemode='w',
                        filename=log_filename)

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        #   X,   Y,   Z, U, V, W, dX, dV, age,
        [ -50, -50, -50, 0, 0, 0, 10.,  5, uniform_age],
        [  50,  50,  50, 0, 0, 0, 10.,  5, uniform_age],
    ])
    starcounts = [20, 50]
    ncomps = sphere_comp_pars.shape[0]

    # initialise z (the true membership probabilities) appropriately
    true_memb_probs = np.zeros((np.sum(starcounts), ncomps))
    start = 0
    for i in range(ncomps):
        true_memb_probs[start:start + starcounts[i], i] = 1.0
        start += starcounts[i]

    # Initialise some random membership probabilities,
    # normalising such that each row sums to 1
    init_memb_probs = np.random.rand(np.sum(starcounts), ncomps)
    init_memb_probs = (init_memb_probs.T / init_memb_probs.sum(axis=1)).T

    synth_data = SynthData(pars=sphere_comp_pars, starcounts=starcounts,
                           Components=SphereComponent,
                           )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table,
                                       write_table=True,
                                       filename=data_filename)

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir,
                                 init_memb_probs=init_memb_probs,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 ignore_stable_comps=False)

    perm = expectmax.get_best_permutation(memb_probs, true_memb_probs)
    logging.info('Best permutation is: {}'.format(perm))

    assert np.allclose(true_memb_probs, memb_probs[:, perm])

    for origin, best_comp in zip(origins, np.array(best_comps)[perm, ]):
        assert (isinstance(origin, SphereComponent) and
                isinstance(best_comp, SphereComponent))
        o_pars = origin.get_pars()
        b_pars = best_comp.get_pars()

        logging.info("origin pars:   {}".format(o_pars))
        logging.info("best fit pars: {}".format(b_pars))
        assert np.allclose(origin.get_mean(), best_comp.get_mean(), atol=5.)
        assert np.allclose(origin.get_sphere_dx(), best_comp.get_sphere_dx(),
                           atol=2.)
        assert np.allclose(origin.get_sphere_dv(), best_comp.get_sphere_dv(),
                           atol=2.)
        assert np.allclose(origin.get_age(), best_comp.get_age(), atol=1.)
def test_fit_one_comp_with_background():
    """
    Synthesise a file with negligible error, retrieve initial parameters

    Takes a while...
    """
    run_name = 'background'
    logging.info(60 * '-')
    logging.info(15 * '-' + '{:^30}'.format('TEST: ' + run_name) + 15 * '-')
    logging.info(60 * '-')

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(
            PY_VERS, run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(
            PY_VERS, run_name)
    logging.basicConfig(level=logging.INFO, filemode='w',
                        filename=log_filename)

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        # X, Y, Z, U, V, W, dX, dV, age,
        [ 0,  0,  0, 0, 0, 0, 10.,  5, uniform_age],
    ])
    starcount = 200
    background_density = 1e-9

    ncomps = sphere_comp_pars.shape[0]

    # true_memb_probs = np.zeros((starcount, ncomps))
    # true_memb_probs[:, 0] = 1.

    synth_data = SynthData(pars=sphere_comp_pars, starcounts=[starcount],
                           Components=SphereComponent,
                           background_density=background_density,
                           )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table)
    background_count = len(synth_data.table) - starcount
    logging.info('Generated {} background stars'.format(background_count))

    # insert background densities
    synth_data.table['background_log_overlap'] = \
        len(synth_data.table) * [np.log(background_density)]
    synth_data.table.write(data_filename, overwrite=True)

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir, burnin=500,
                                 sampling_steps=5000,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 use_background=True,
                                 ignore_stable_comps=False,
                                 max_em_iterations=200)
    # return best_comps, med_and_spans, memb_probs

    # Check parameters are close
    assert np.allclose(sphere_comp_pars, best_comps[0].get_pars(), atol=1.5)

    # Check most association members are correctly classified
    recovery_count_threshold = 0.95 * starcount
    recovery_count_actual = np.sum(memb_probs[:starcount, 0] > 0.5)
    assert recovery_count_threshold < recovery_count_actual

    # Check most background stars are correctly classified:
    # the number of background stars classified as members should be
    # less than 5% of all background stars
    contamination_count_threshold = 0.05 * len(memb_probs[starcount:])
    contamination_count_actual = np.sum(memb_probs[starcount:, 0] > 0.5)
    assert contamination_count_threshold > contamination_count_actual

    # Check reported membership probabilities are consistent with the
    # recovery rate (within 5%)
    mean_membership_confidence = np.mean(memb_probs[:starcount, 0])
    assert np.isclose(recovery_count_actual / starcount,
                      mean_membership_confidence, atol=0.05)
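# The recovery/contamination bookkeeping above is easy to sanity check by
# hand. A toy example with hypothetical numbers (4 stars: the first 3 are
# true members, the last is a background star; column 0 holds the
# component membership probability):
import numpy as np

toy_memb_probs = np.array([[0.98],
                           [0.90],
                           [0.40],    # a missed member (prob below 0.5)
                           [0.03]])   # a correctly rejected background star
toy_starcount = 3

recovery_count = np.sum(toy_memb_probs[:toy_starcount, 0] > 0.5)       # 2 of 3
contamination_count = np.sum(toy_memb_probs[toy_starcount:, 0] > 0.5)  # 0 of 1
mean_confidence = np.mean(toy_memb_probs[:toy_starcount, 0])           # 0.76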
# handle special case of one component
if ncomps == 1:
    logging.info("******************************************")
    logging.info("********* FITTING 1 COMPONENT **********")
    logging.info("******************************************")
    run_dir = rdir + '{}/'.format(ncomps)
    mkpath(run_dir)

    try:
        new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
        new_meds = np.load(run_dir + 'final/final_med_errs.npy')
        new_z = np.load(run_dir + 'final/final_membership.npy')
        logging.info("Loaded from previous run")
    except IOError:
        new_groups, new_meds, new_z = \
            em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                              bg_ln_ols=bg_ln_ols)
        # kill component here
        new_groups = np.array(new_groups)

    new_lnlike = em.get_overall_lnlikelihood(star_pars, new_groups,
                                             bg_ln_ols=bg_ln_ols)
    new_lnpost = em.get_overall_lnlikelihood(star_pars, new_groups,
                                             bg_ln_ols=bg_ln_ols,
                                             inc_posterior=True)
    new_BIC = em.calc_bic(star_pars, ncomps, new_lnlike)
# handle multiple components
else:
    logging.info("******************************************")
    logging.info("********* FITTING {} COMPONENTS *********".\
        format(ncomps))
    logging.info("******************************************")
    best_fits = []
all_xyzuvw_init, origins = syn.synthesiseManyXYZUVW(group_pars, sphere=True,
                                                    return_groups=True)
logging.info(" done")

# xyzuvw_init, group = \
#     syn.synthesiseXYZUVW(group_pars, sphere=True, return_group=True)
# all_xyzuvw_init = np.vstack((all_xyzuvw_init, xyzuvw_init))
#
# xyzuvw_now_perf = torb.traceManyOrbitXYZUVW(xyzuvw_init, group.age,
#                                             single_age=True)
# all_xyzuvw_now_perf = \
#     np.vstack((all_xyzuvw_now_perf, xyzuvw_now_perf))
# origins.append(group)

logging.info("Saving synthetic data...")
np.save(rdir + groups_savefile, origins)
np.save(rdir + xyzuvw_perf_file, all_xyzuvw_now_perf)
astro_table = chronostar.synthdata.measureXYZUVW(all_xyzuvw_now_perf, 1.0,
                                                 savefile=rdir + astro_savefile)
star_pars = cv.convertMeasurementsToCartesian(
    astro_table, savefile=rdir + xyzuvw_conv_savefile,
)

em.fit_many_comps(star_pars, ngroups, origins=origins,
                  rdir=rdir, pool=pool,
                  # init_with_origin=True
                  )

if using_mpi:
    pool.close()
if ncomps == 1:
    logging.info("******************************************")
    logging.info("********* FITTING 1 COMPONENT **********")
    logging.info("******************************************")
    run_dir = rdir + '{}/'.format(ncomps)
    mkpath(run_dir)

    try:
        new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
        new_meds = np.load(run_dir + 'final/final_med_errs.npy')
        new_z = np.load(run_dir + 'final/final_membership.npy')
        logging.info("Loaded from previous run")
    except IOError:
        new_groups, new_meds, new_z = \
            em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                              bg_dens=BG_DENS,
                              )
        new_groups = np.array(new_groups)

    new_lnlike = em.get_overall_lnlikelihood(star_pars, new_groups,
                                             bg_ln_ols=bg_ln_ols)
    new_lnpost = em.get_overall_lnlikelihood(star_pars, new_groups,
                                             bg_ln_ols=bg_ln_ols,
                                             inc_posterior=True)
    new_BIC = em.calc_bic(star_pars, ncomps, new_lnlike)
# handle multiple components
else:
    logging.info("******************************************")
    logging.info("********* FITTING {} COMPONENTS *********".\
        format(ncomps))
    final_cdir = run_dir + 'final/comp{}/'.format(i)
    chain = np.load(final_cdir + 'final_chain.npy')
    lnprob = np.load(final_cdir + 'final_lnprob.npy')
    npars = len(Component.PARAMETER_FORMAT)
    best_ix = np.argmax(lnprob)
    best_pars = chain.reshape(-1, npars)[best_ix]
    prev_comps[i] = Component(emcee_pars=best_pars)
    np.save(str(run_dir + 'final/' + final_comps_file), prev_comps)
    logging.info('Loaded from previous run')
except IOError:
    prev_comps, prev_med_and_spans, prev_memb_probs = \
        expectmax.fit_many_comps(data=data_dict, ncomps=ncomps,
                                 rdir=run_dir,
                                 trace_orbit_func=trace_orbit_func,
                                 burnin=config.advanced['burnin_steps'],
                                 sampling_steps=config.advanced['sampling_steps'],
                                 use_background=config.config[
                                     'include_background_distribution'],
                                 init_memb_probs=init_memb_probs,
                                 )

# Calculate global score of fit for comparison with future fits with
# different component counts
prev_lnlike = expectmax.get_overall_lnlikelihood(data_dict, prev_comps,
                                                 # bg_ln_ols=bg_ln_ols,
                                                 )
prev_lnpost = expectmax.get_overall_lnlikelihood(data_dict, prev_comps,
                                                 # bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
def test_fit_one_comp_with_background():
    """
    Synthesise a file with negligible error, retrieve initial parameters

    Takes a while... maybe this belongs in integration unit_tests
    """
    run_name = 'background'

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(PY_VERS,
                                                                 run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(PY_VERS,
                                                              run_name)
    logging.basicConfig(level=logging.INFO, filemode='w',
                        filename=log_filename)

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        # X, Y, Z, U, V, W, dX, dV, age,
        [ 0,  0,  0, 0, 0, 0, 10.,  5, uniform_age],
    ])
    starcount = 100
    background_density = 1e-9

    ncomps = sphere_comp_pars.shape[0]

    # true_memb_probs = np.zeros((starcount, ncomps))
    # true_memb_probs[:, 0] = 1.

    synth_data = SynthData(pars=sphere_comp_pars, starcounts=[starcount],
                           Components=SphereComponent,
                           background_density=background_density,
                           )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table,
                                       write_table=True,
                                       filename=data_filename)
    background_count = len(synth_data.table) - starcount

    # insert background densities
    synth_data.table['background_log_overlap'] = \
        len(synth_data.table) * [np.log(background_density)]

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 use_background=True)
    # return best_comps, med_and_spans, memb_probs

    # Check parameters are close
    assert np.allclose(sphere_comp_pars, best_comps[0].get_pars(), atol=1.)

    # Check most association members are correctly classified
    recovery_count_threshold = 0.95 * starcount
    recovery_count_actual = np.sum(np.round(memb_probs[:starcount, 0]))
    assert recovery_count_threshold < recovery_count_actual

    # Check most background stars are correctly classified:
    # contamination should stay below 5% of all background stars
    contamination_count_threshold = 0.05 * len(memb_probs[starcount:])
    contamination_count_actual = np.sum(np.round(memb_probs[starcount:, 0]))
    assert contamination_count_threshold > contamination_count_actual

    # Check reported membership probabilities are consistent with the
    # recovery rate (within 5%)
    mean_membership_confidence = np.mean(memb_probs[:starcount, 0])
    assert np.isclose(recovery_count_actual / starcount,
                      mean_membership_confidence, atol=0.05)
def test_fit_many_comps():
    """
    Synthesise a file with negligible error, retrieve initial parameters

    Takes a while... maybe this belongs in integration unit_tests
    """
    run_name = 'stationary'

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(PY_VERS,
                                                                 run_name)
    log_filename = 'temp_data/{}_expectmax_{}/log.log'.format(PY_VERS,
                                                              run_name)
    logging.basicConfig(level=logging.INFO, filemode='w',
                        filename=log_filename)

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        #   X,   Y,   Z, U, V, W, dX, dV, age,
        [ -50, -50, -50, 0, 0, 0, 10.,  5, uniform_age],
        [  50,  50,  50, 0, 0, 0, 10.,  5, uniform_age],
    ])
    starcounts = [200, 200]
    ncomps = sphere_comp_pars.shape[0]

    # initialise z appropriately
    # start = 0
    # for i in range(ngroups):
    #     nstars_in_group = int(group_pars[i, -1])
    #     z[start:start + nstars_in_group, i] = 1.0
    #     start += nstars_in_group
    true_memb_probs = np.zeros((np.sum(starcounts), ncomps))
    true_memb_probs[:200, 0] = 1.
    true_memb_probs[200:, 1] = 1.

    synth_data = SynthData(pars=sphere_comp_pars, starcounts=starcounts,
                           Components=SphereComponent,
                           )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table,
                                       write_table=True,
                                       filename=data_filename)

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 )

    # compare fit with input
    try:
        assert np.allclose(true_memb_probs, memb_probs)
    except AssertionError:
        # If not close, check if flipping component order fixes things
        memb_probs = memb_probs[:, ::-1]
        best_comps = best_comps[::-1]
        assert np.allclose(true_memb_probs, memb_probs)

    for origin, best_comp in zip(origins, best_comps):
        assert (isinstance(origin, SphereComponent) and
                isinstance(best_comp, SphereComponent))
        o_pars = origin.get_pars()
        b_pars = best_comp.get_pars()

        logging.info("origin pars:   {}".format(o_pars))
        logging.info("best fit pars: {}".format(b_pars))
        assert np.allclose(origin.get_mean(), best_comp.get_mean(), atol=5.)
        assert np.allclose(origin.get_sphere_dx(), best_comp.get_sphere_dx(),
                           atol=2.)
        assert np.allclose(origin.get_sphere_dv(), best_comp.get_sphere_dv(),
                           atol=2.)
        assert np.allclose(origin.get_age(), best_comp.get_age(), atol=1.)
# Build a first-guess component from the data itself: use the smallest
# positional and velocity variances of the star means as dX and dV
bp_cov = np.cov(star_means.T)
bp_dx = np.sqrt(np.min([bp_cov[0, 0], bp_cov[1, 1], bp_cov[2, 2]]))
bp_dv = np.sqrt(np.min([bp_cov[3, 3], bp_cov[4, 4], bp_cov[5, 5]]))
bp_age = 0.5
bp_pars = np.hstack((bp_mean, bp_dx, bp_dv, bp_age, nstars))
bp_group = chronostar.component.Component(bp_pars)
origins = [bp_group]

# Go through and compare overlap with groups with background overlap
#
# bp_pars = np.array([
#     2.98170398e+01, 4.43573995e+01, 2.29251498e+01, -9.65731744e-01,
#     -3.42827894e+00, -3.99928052e-02, 2.63084094e+00, 1.05302890e-01,
#     1.59367119e+01, nstars,
# ])
# bp_group = syn.Group(bp_pars)

# --------------------------------------------------------------------------
# Run fit
# --------------------------------------------------------------------------
logging.info("Using data file {}".format(xyzuvw_file))
logging.info("Everything loaded, about to fit with {} components"\
             .format(NGROUPS))
em.fit_many_comps(star_pars, NGROUPS, rdir=rdir, pool=pool, offset=True,
                  bg_hist_file=bg_hist_file, origins=origins,
                  init_with_origin=init_origin,
                  )

if using_mpi:
    pool.close()
    best_ix = np.argmax(lnprob)
    best_pars = chain.reshape(-1, npars)[best_ix]
    prev_comps[i] = Component(emcee_pars=best_pars)
    Component.store_raw_components(str(run_dir + 'final/' + final_comps_file),
                                   prev_comps)
    # np.save(str(run_dir + 'final/' + final_comps_file), prev_comps)
    logging.info('Loaded from previous run')
except IOError:
    prev_comps, prev_med_and_spans, prev_memb_probs = \
        expectmax.fit_many_comps(data=data_dict, ncomps=ncomps,
                                 rdir=run_dir,
                                 trace_orbit_func=trace_orbit_func,
                                 burnin=config.advanced['burnin_steps'],
                                 sampling_steps=config.advanced['sampling_steps'],
                                 use_background=config.config[
                                     'include_background_distribution'],
                                 init_memb_probs=init_memb_probs,
                                 Component=Component,
                                 store_burnin_chains=store_burnin_chains,
                                 max_iters=MAX_ITERS,
                                 )

# Calculate global score of fit for comparison with future fits with
# different component counts
prev_lnlike = expectmax.get_overall_lnlikelihood(data_dict, prev_comps,
                                                 # bg_ln_ols=bg_ln_ols,
                                                 )
prev_lnpost = expectmax.get_overall_lnlikelihood(data_dict, prev_comps,
                                                 # bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
# handle special case of one component
if ncomps == 1:
    logging.info("******************************************")
    logging.info("********* FITTING 1 COMPONENT **********")
    logging.info("******************************************")
    run_dir = rdir + '{}/'.format(ncomps)
    mkpath(run_dir)

    try:
        new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
        new_meds = np.load(run_dir + 'final/final_med_errs.npy')
        new_z = np.load(run_dir + 'final/final_membership.npy')
        logging.info("Loaded from previous run")
    except IOError:
        new_groups, new_meds, new_z = \
            em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                              bg_ln_ols=bg_ln_ols)
        new_groups = np.array(new_groups)

    new_lnlike = em.get_overall_lnlikelihood(star_pars, new_groups,
                                             bg_ln_ols=bg_ln_ols)
    new_lnpost = em.get_overall_lnlikelihood(star_pars, new_groups,
                                             bg_ln_ols=bg_ln_ols,
                                             inc_posterior=True)
    new_BIC = em.calc_bic(star_pars, ncomps, new_lnlike)
# handle multiple components
else:
    logging.info("******************************************")
    logging.info("********* FITTING {} COMPONENTS *********".\
        format(ncomps))
    logging.info("******************************************")
    best_fits = []
    final_cdir = run_dir + 'final/comp{}/'.format(i)
    chain = np.load(final_cdir + 'final_chain.npy')
    lnprob = np.load(final_cdir + 'final_lnprob.npy')
    npars = len(Component.PARAMETER_FORMAT)
    best_ix = np.argmax(lnprob)
    best_pars = chain.reshape(-1, npars)[best_ix]
    prev_comps[i] = Component(emcee_pars=best_pars)
    np.save(str(run_dir + 'final/' + final_comps_file), prev_comps)
    logging.info('Loaded from previous run')
except IOError:
    prev_comps, prev_med_and_spans, prev_memb_probs = \
        expectmax.fit_many_comps(data=data_dict, ncomps=ncomps,
                                 rdir=run_dir,
                                 trace_orbit_func=trace_orbit_func,
                                 burnin=config.advanced['burnin_steps'],
                                 sampling_steps=config.advanced['sampling_steps'],
                                 use_background=config.config[
                                     'include_background_distribution'],
                                 init_memb_probs=init_memb_probs,
                                 )

# Calculate global score of fit for comparison with future fits with
# different component counts
prev_lnlike = expectmax.get_overall_lnlikelihood(data_dict, prev_comps,
                                                 # bg_ln_ols=bg_ln_ols,
                                                 )
prev_lnpost = expectmax.get_overall_lnlikelihood(data_dict, prev_comps,
                                                 # bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
prev_bic = expectmax.calc_bic(data_dict, ncomps, prev_lnlike)
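# expectmax.calc_bic scores each fit so that runs with different component
# counts can be compared. Its exact parameter counting is not shown in this
# section; a minimal sketch, assuming the standard BIC = k*ln(n) - 2*ln(L)
# with 9 free parameters per sphere component (X, Y, Z, U, V, W, dX, dV,
# age) plus ncomps - 1 mixture weights:
import numpy as np


def calc_bic_sketch(nstars, ncomps, lnlike, npars_per_comp=9):
    """Hypothetical helper: lower BIC means the extra components are
    justified by the gain in likelihood."""
    k = ncomps * npars_per_comp + (ncomps - 1)
    return k * np.log(nstars) - 2. * lnlike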
init_z = init_z.T
print(np.sum(init_z, axis=0))
# Column sums give each component's expected star count; the -1 presumably
# excludes the background membership column
ncomps = len(np.sum(init_z, axis=0)) - 1
print('ncomps: %d' % ncomps)
print('%d components' % (len(np.sum(init_z, axis=0)) - 1))
print('init_z successful!! Yey')

# --------------------------------------------------------------------------
# Perform one EM fit
# --------------------------------------------------------------------------
logging.info("Using data file {}".format(xyzuvw_file))
logging.info("Everything loaded, about to fit with {} components"\
             .format(ncomps))
final_groups, final_med_errs, z = em.fit_many_comps(star_pars, ncomps,
                                                    bg_ln_ols=bg_ln_ols,
                                                    rdir=rdir, pool=pool,
                                                    init_memb_probs=init_z,
                                                    ignore_dead_comps=True)

# --------------------------------------------------------------------------
# Repeat EM fit but kill components with too few members
# --------------------------------------------------------------------------
# e.g. maybe < 3 expected star count
# take results from entire past EM fit
#
# init_origin = False
# origins = None
# if NGROUPS == 1:
#     init_origin = True
#     nstars = star_means.shape[0]
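# The commented-out follow-up above plans a pruning pass: after an EM fit,
# kill components whose expected member count (the column sum of the
# membership matrix) falls below a threshold, then refit. A minimal sketch
# of that selection with a hypothetical helper:
import numpy as np


def surviving_comp_mask(z, min_expected_members=3):
    """Keep only components whose expected star count (column sum of the
    membership matrix z) meets the threshold."""
    expected_counts = z.sum(axis=0)
    return expected_counts >= min_expected_members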
if ncomps == 1:
    logging.info("******************************************")
    logging.info("********* FITTING 1 COMPONENT **********")
    logging.info("******************************************")
    run_dir = rdir + '{}/'.format(ncomps)
    mkpath(run_dir)

    try:
        new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
        new_meds = np.load(run_dir + 'final/final_med_errs.npy')
        new_z = np.load(run_dir + 'final/final_membership.npy')
        logging.info("Loaded from previous run")
    except IOError:
        new_groups, new_meds, new_z = \
            em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                              bg_dens=BG_DENS,
                              )
        new_groups = np.array(new_groups)

    new_lnlike = em.get_overall_lnlikelihood(star_pars, new_groups,
                                             bg_ln_ols=bg_ln_ols)
    new_lnpost = em.get_overall_lnlikelihood(star_pars, new_groups,
                                             bg_ln_ols=bg_ln_ols,
                                             inc_posterior=True)
    new_BIC = em.calc_bic(star_pars, ncomps, new_lnlike)
# handle multiple components
else:
    logging.info("******************************************")
    logging.info("********* FITTING {} COMPONENTS *********".\
        format(ncomps))
    logging.info("******************************************")
#
# xyzuvw_now_perf = torb.traceManyOrbitXYZUVW(xyzuvw_init, group.age,
#                                             single_age=True)
# all_xyzuvw_now_perf = \
#     np.vstack((all_xyzuvw_now_perf, xyzuvw_now_perf))
# origins.append(group)

logging.info("Saving synthetic data...")
np.save(rdir + groups_savefile, origins)
np.save(rdir + xyzuvw_perf_file, all_xyzuvw_now_perf)
astro_table = chronostar.synthdata.measureXYZUVW(all_xyzuvw_now_perf, 1.0,
                                                 savefile=rdir + astro_savefile)
star_pars = cv.convertMeasurementsToCartesian(
    astro_table, savefile=rdir + xyzuvw_conv_savefile,
)
em.fit_many_comps(
    star_pars, ngroups, origins=origins,
    rdir=rdir, pool=pool,
    # init_with_origin=True
)

if using_mpi:
    pool.close()