def test_train_eval():
    dataset_path = generate_dataset_file(5, 7)
    with tempdir() as dirname:
        models_dir = os.path.join(dirname, 'models')
        result_path = os.path.join(dirname, 'tune_clusters.pkz')
        validate.train(dataset_path, PARAM_CSV, models_dir)
        validate.eval(dataset_path, PARAM_CSV, models_dir, result_path)
        assert os.path.exists(result_path)
def test_train():
    dataset_path = generate_dataset_file(10, 10)
    with tempdir() as dirname:
        ensemble_path = os.path.join(dirname, 'ensemble.pkz')
        # Config values are passed as strings, as they would arrive from the
        # command line.
        train(
            dataset_path,
            ensemble_path,
            model_ensemble_size='3',
            learning_init_epochs='3')
        assert os.path.exists(ensemble_path)
def test_guess_schema():
    with tempdir() as dirname:
        types_csv_out = os.path.join(dirname, 'types.csv')
        values_csv_out = os.path.join(dirname, 'values.csv')
        guess_schema(DATA_CSV, types_csv_out, values_csv_out)
        expected_types = open(TYPES_CSV).read()
        expected_values = open(VALUES_CSV).read()
        actual_types = open(types_csv_out).read()
        actual_values = open(values_csv_out).read()
        assert actual_types == expected_types
        assert actual_values == expected_values
def serve(rows=100, cols=10, cats=4, tool='timers'):
    """Profile TreeCatServer on a random dataset.

    Available tools: timers, time, snakeviz, line_profiler, pdb
    """
    from treecat.generate import generate_model_file
    config = make_config()
    model_path = generate_model_file(rows, cols, cats)
    with tempdir() as dirname:
        config_path = os.path.join(dirname, 'config.pkz')
        pickle_dump(config, config_path)
        cmd = [FILE, 'serve_files', model_path, config_path, str(rows)]
        run_with_tool(cmd, tool, dirname)
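# Example: profile serving on a larger random model under snakeviz. A sketch
# only; the sizes below are illustrative and the `treecat.profile` module
# path is an assumption, not confirmed by this file.
#
#     from treecat.profile import serve
#     serve(rows=1000, cols=20, cats=8, tool='snakeviz')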
def train(rows=100, cols=10, epochs=5, clusters=32, parallel=False,
          tool='timers'):
    """Profile TreeCatTrainer on a random dataset.

    Available tools: timers, time, snakeviz, line_profiler, pdb
    """
    from treecat.generate import generate_dataset_file
    config = make_config(
        learning_init_epochs=epochs,
        model_num_clusters=clusters,
        model_ensemble_size=1,
        learning_parallel=parallel)
    dataset_path = generate_dataset_file(rows, cols)
    with tempdir() as dirname:
        config_path = os.path.join(dirname, 'config.pkz')
        pickle_dump(config, config_path)
        cmd = [FILE, 'train_files', dataset_path, config_path]
        run_with_tool(cmd, tool, dirname)
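# Example: profile a parallel training run line-by-line. A sketch only;
# sizes are illustrative and `treecat.profile` is an assumed module path.
#
#     from treecat.profile import train
#     train(rows=500, cols=20, epochs=10, clusters=64, parallel=True,
#           tool='line_profiler')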
def eval(rows=100, cols=10, cats=4, tool='timers'):
    """Profile treecat.validate.eval on a random dataset.

    Available tools: timers, time, snakeviz, line_profiler, pdb
    """
    from treecat.generate import generate_dataset_file
    from treecat.validate import train
    dataset_path = generate_dataset_file(rows, cols)
    validate_py = os.path.join(os.path.dirname(FILE), 'validate.py')
    with tempdir() as dirname:
        param_csv_path = os.path.join(dirname, 'param.csv')
        with open(param_csv_path, 'w') as f:
            # A one-column hyperparameter grid with a single configuration.
            f.write('learning_init_epochs\n2')
        train(dataset_path, param_csv_path, dirname, learning_init_epochs=2)
        cmd = [
            validate_py,
            'eval',
            dataset_path,
            param_csv_path,
            dirname,
            os.path.join(dirname, 'tuning.pkz'),
            'learning_init_epochs=2',
        ]
        run_with_tool(cmd, tool, dirname)
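# Example: step through hyperparameter evaluation interactively. A sketch
# only; `treecat.profile` is an assumed module path.
#
#     from treecat.profile import eval
#     eval(rows=200, cols=10, tool='pdb')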
def test_e2e(model_type):
    with tempdir() as dirname:
        data_csv = os.path.join(TESTDATA, 'tiny_data.csv')
        config = TINY_CONFIG.copy()

        print('Guess schema.')
        types_csv = os.path.join(dirname, 'types.csv')
        values_csv = os.path.join(dirname, 'values.csv')
        guess_schema(data_csv, types_csv, values_csv)

        print('Load schema.')
        groups_csv = os.path.join(TESTDATA, 'tiny_groups.csv')
        schema = load_schema(types_csv, values_csv, groups_csv)
        ragged_index = schema['ragged_index']
        tree_prior = schema['tree_prior']

        print('Load data.')
        data = load_data(schema, data_csv)
        feature_types = [TY_MULTINOMIAL] * len(schema['feature_names'])
        table = Table(feature_types, ragged_index, data)
        dataset = {
            'schema': schema,
            'table': table,
        }

        print('Train model.')
        if model_type == 'single':
            model = train_model(table, tree_prior, config)
        elif model_type == 'ensemble':
            model = train_ensemble(table, tree_prior, config)
        else:
            raise ValueError(model_type)

        print('Serve model.')
        server = serve_model(dataset, model)

        print('Query model.')
        evidence = {'genre': 'drama'}
        server.logprob([evidence])
        samples = server.sample(100)
        server.logprob(samples)
        samples = server.sample(100, evidence)
        server.logprob(samples)
        try:
            median = server.median([evidence])
            server.logprob(median)
        except NotImplementedError:
            warn('{} median not implemented'.format(model_type))
        try:
            mode = server.mode([evidence])
            server.logprob(mode)
        except NotImplementedError:
            warn('{} mode not implemented'.format(model_type))

        print('Examine latent structure.')
        server.feature_density()
        server.observed_perplexity()
        server.latent_perplexity()
        server.latent_correlation()
        server.estimate_tree()
        server.sample_tree(10)

        print('Plot latent structure.')
        plot_circular(server)
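# The query API exercised above, in isolation (a sketch; the evidence keys
# depend on the dataset's feature names, and 'genre' is specific to the tiny
# test data):
#
#     evidence = {'genre': 'drama'}           # condition on one feature
#     logp = server.logprob([evidence])       # log density of the evidence
#     samples = server.sample(100, evidence)  # 100 conditional samples
#     corr = server.latent_correlation()      # feature-feature correlation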
def test_import_data():
    with tempdir() as dirname:
        dataset_out = os.path.join(dirname, 'dataset.pkz')
        assert not os.path.exists(dataset_out)
        import_data(DATA_CSV, TYPES_CSV, VALUES_CSV, GROUPS_CSV, dataset_out)
        assert os.path.exists(dataset_out)
def test_pickle(data, ext):
    with tempdir() as dirname:
        filename = os.path.join(dirname, 'test.{}'.format(ext))
        pickle_dump(data, filename)
        actual = pickle_load(filename)
        assert_equal(actual, data)
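# Example round trip (a sketch; `treecat.util` is assumed to be the home of
# these helpers, and .pkz is assumed to denote a compressed pickle in this
# codebase):
#
#     from treecat.util import pickle_dump, pickle_load
#     pickle_dump({'x': 1}, '/tmp/example.pkz')
#     assert pickle_load('/tmp/example.pkz') == {'x': 1}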