def test_fit_stability_mixed_comps():
    """
    Have a fit with some iterations that have a mix of stable and unstable
    comps.

    Three spherical components are synthesised: one distant component whose
    members are initialised correctly (so it should become stable quickly)
    and two overlapping components that must be disentangled over several
    EM iterations. The fit is run with ``ignore_stable_comps=True`` and the
    recovered memberships and component parameters are checked against the
    truth.

    TODO: Maybe give 2 similar comps tiny age but overlapping origins
    """
    run_name = 'mixed_stability'

    savedir = 'temp_data/{}_expectmax_{}/'.format(PY_VERS, run_name)
    mkpath(savedir)
    data_filename = savedir + '{}_expectmax_{}_data.fits'.format(
        PY_VERS, run_name)
    log_filename = savedir + 'log.log'

    # Configure logging BEFORE the first logging call: logging.info() on an
    # unconfigured root logger installs a default stderr handler, after
    # which basicConfig() silently does nothing and no log file is written.
    # NOTE(review): if the test runner has already attached handlers to the
    # root logger (e.g. a pytest logging plugin), this call is still a no-op.
    logging.basicConfig(level=logging.INFO, filemode='w',
                        filename=log_filename)

    logging.info(60 * '-')
    logging.info(15 * '-' + '{:^30}'.format('TEST: ' + run_name) + 15 * '-')
    logging.info(60 * '-')

    uniform_age = 1e-10
    sphere_comp_pars = np.array([
        # X,   Y, Z, U,  V, W,  dX, dV,         age,
        [50,   0, 0, 0, 50, 0, 10., 5, uniform_age],  # Very distant (and stable) comp
        [0,  -20, 0, 0, -5, 0, 10., 5, uniform_age],  # Overlapping comp 1
        [0,   20, 0, 0,  5, 0, 10., 5, uniform_age],  # Overlapping comp 2
    ])
    starcounts = [50, 100, 200]
    ncomps = sphere_comp_pars.shape[0]
    nstars = np.sum(starcounts)

    # True memberships: stars are laid out contiguously by component, and
    # each star belongs fully to exactly one component.
    true_memb_probs = np.zeros((nstars, ncomps))
    start = 0
    for comp_ix, count in enumerate(starcounts):
        true_memb_probs[start:start + count, comp_ix] = 1.0
        start += count

    # Initialise some random membership probabilities, which serve as the
    # starting guess for the fit.
    init_memb_probs = np.random.rand(nstars, ncomps)
    # To help one component become stable quickly, initialise the
    # memberships correctly for the stars belonging to it (and exclude all
    # other stars from it).
    init_memb_probs[:starcounts[0]] = 0.
    init_memb_probs[:starcounts[0], 0] = 1.
    init_memb_probs[starcounts[0]:, 0] = 0.
    # Normalise such that each row sums to 1
    init_memb_probs = init_memb_probs / init_memb_probs.sum(axis=1,
                                                            keepdims=True)

    synth_data = SynthData(
        pars=sphere_comp_pars, starcounts=starcounts,
        Components=SphereComponent,
    )
    synth_data.synthesise_everything()
    tabletool.convert_table_astro2cart(synth_data.table,
                                       write_table=True,
                                       filename=data_filename)

    origins = [SphereComponent(pars) for pars in sphere_comp_pars]
    SphereComponent.store_raw_components(savedir + 'origins.npy', origins)

    best_comps, med_and_spans, memb_probs = \
        expectmax.fit_many_comps(data=synth_data.table, ncomps=ncomps,
                                 rdir=savedir,
                                 init_memb_probs=init_memb_probs,
                                 trace_orbit_func=dummy_trace_orbit_func,
                                 ignore_stable_comps=True)

    perm = expectmax.get_best_permutation(memb_probs, true_memb_probs)
    logging.info('Best permutation is: {}'.format(perm))

    # Calculate the membership difference. A misallocated star contributes
    # to |true - fit| in two columns, so halve the sum to count it once.
    total_diff = 0.5 * np.sum(np.abs(true_memb_probs - memb_probs[:, perm]))

    # Assert that fewer than 10% of the stars are misallocated
    assert total_diff < 0.1 * nstars

    # Index best_comps directly with the permutation rather than first
    # building an intermediate numpy array of component objects.
    for origin, comp_ix in zip(origins, perm):
        best_comp = best_comps[comp_ix]
        assert (isinstance(origin, SphereComponent) and
                isinstance(best_comp, SphereComponent))
        o_pars = origin.get_pars()
        b_pars = best_comp.get_pars()

        logging.info("origin pars: {}".format(o_pars))
        logging.info("best fit pars: {}".format(b_pars))
        assert np.allclose(origin.get_mean(), best_comp.get_mean(),
                           atol=5.)
        assert np.allclose(origin.get_sphere_dx(),
                           best_comp.get_sphere_dx(),
                           atol=2.)
        assert np.allclose(origin.get_sphere_dv(),
                           best_comp.get_sphere_dv(),
                           atol=2.)
        assert np.allclose(origin.get_age(), best_comp.get_age(),
                           atol=1.)
str(run_dir + 'final/' + final_comps_file)) # Final comps are there, they just can't be read by current module # so quickly fit them based on fixed prev membership probabilities except AttributeError: logging.info('Component class has been modified, reconstructing ' 'from chain') prev_comps = ncomps * [None] for i in range(ncomps): final_cdir = run_dir + 'final/comp{}/'.format(i) chain = np.load(final_cdir + 'final_chain.npy') lnprob = np.load(final_cdir + 'final_lnprob.npy') npars = len(Component.PARAMETER_FORMAT) best_ix = np.argmax(lnprob) best_pars = chain.reshape(-1, npars)[best_ix] prev_comps[i] = Component(emcee_pars=best_pars) Component.store_raw_components( str(run_dir + 'final/' + final_comps_file), prev_comps) # np.save(str(run_dir+'final/'+final_comps_file), prev_comps) logging.info('Loaded from previous run') except IOError: prev_comps, prev_med_and_spans, prev_memb_probs = \ expectmax.fit_many_comps(data=data_dict, ncomps=ncomps, rdir=run_dir, trace_orbit_func=trace_orbit_func, burnin=config.advanced['burnin_steps'], sampling_steps=config.advanced['sampling_steps'], use_background=config.config[ 'include_background_distribution'], init_memb_probs=init_memb_probs, Component=Component, store_burnin_chains=store_burnin_chains, max_iters=MAX_ITERS,