예제 #1
0
def test_train_eval():
    """Smoke-test validate.train followed by validate.eval.

    Trains into a temporary models directory using the parameters in
    PARAM_CSV, evaluates those models, and checks that eval wrote its
    result file.
    """
    dataset_path = generate_dataset_file(5, 7)
    with tempdir() as workdir:
        models_dir, result_path = (
            os.path.join(workdir, name)
            for name in ('models', 'tune_clusters.pkz'))
        validate.train(dataset_path, PARAM_CSV, models_dir)
        validate.eval(dataset_path, PARAM_CSV, models_dir, result_path)
        assert os.path.exists(result_path)
예제 #2
0
def test_train():
    """Smoke-test train(): it should write an ensemble file to the target path.

    The option values are passed as strings, exactly as a command line
    would supply them.
    """
    dataset_path = generate_dataset_file(10, 10)
    options = {
        'model_ensemble_size': '3',
        'learning_init_epochs': '3',
    }
    with tempdir() as workdir:
        ensemble_path = os.path.join(workdir, 'ensemble.pkz')
        train(dataset_path, ensemble_path, **options)
        assert os.path.exists(ensemble_path)
예제 #3
0
def test_guess_schema():
    """guess_schema() should reproduce the checked-in TYPES_CSV / VALUES_CSV.

    Runs guess_schema on DATA_CSV, writing into a temporary directory, and
    compares each generated file's full text against the expected fixture.
    """

    def read(path):
        # Close every handle promptly. Unclosed handles emit ResourceWarnings
        # and, on platforms that refuse to delete open files, can make
        # tempdir() cleanup fail.
        with open(path) as f:
            return f.read()

    with tempdir() as dirname:
        types_csv_out = os.path.join(dirname, 'types.csv')
        values_csv_out = os.path.join(dirname, 'values.csv')
        guess_schema(DATA_CSV, types_csv_out, values_csv_out)
        assert read(types_csv_out) == read(TYPES_CSV)
        assert read(values_csv_out) == read(VALUES_CSV)
예제 #4
0
def serve(rows=100, cols=10, cats=4, tool='timers'):
    """Profile TreeCatServer on a random dataset.
    Available tools: timers, time, snakeviz, line_profiler, pdb

    Generates a random model of size (rows, cols, cats) and pickles a
    default config into a temporary directory. It then runs the
    'serve_files' subcommand of FILE under the chosen profiling tool.
    """
    from treecat.generate import generate_model_file
    config = make_config()
    model_path = generate_model_file(rows, cols, cats)
    with tempdir() as workdir:
        config_path = os.path.join(workdir, 'config.pkz')
        pickle_dump(config, config_path)
        command = [
            FILE,
            'serve_files',
            model_path,
            config_path,
            str(rows),
        ]
        run_with_tool(command, tool, workdir)
예제 #5
0
def train(rows=100,
          cols=10,
          epochs=5,
          clusters=32,
          parallel=False,
          tool='timers'):
    """Profile TreeCatTrainer on a random dataset.
    Available tools: timers, time, snakeviz, line_profiler, pdb

    Builds a single-model (ensemble size 1) config from the given epochs,
    cluster count and parallel flag. It generates a random rows x cols
    dataset, then runs FILE's 'train_files' subcommand under the chosen
    profiling tool.
    """
    from treecat.generate import generate_dataset_file
    overrides = dict(
        learning_init_epochs=epochs,
        model_num_clusters=clusters,
        model_ensemble_size=1,
        learning_parallel=parallel,
    )
    config = make_config(**overrides)
    dataset_path = generate_dataset_file(rows, cols)
    with tempdir() as workdir:
        config_path = os.path.join(workdir, 'config.pkz')
        pickle_dump(config, config_path)
        command = [FILE, 'train_files', dataset_path, config_path]
        run_with_tool(command, tool, workdir)
예제 #6
0
def eval(rows=100, cols=10, cats=4, tool='timers'):
    """Profile treecat.validate.eval on a random dataset.
    Available tools: timers, time, snakeviz, line_profiler, pdb

    Note: ``cats`` is accepted but not currently used by this function.
    """
    from treecat.generate import generate_dataset_file
    from treecat.validate import train
    dataset_path = generate_dataset_file(rows, cols)
    validate_py = os.path.join(os.path.dirname(FILE), 'validate.py')
    with tempdir() as workdir:
        param_csv_path = os.path.join(workdir, 'param.csv')
        # Parameter file: header 'learning_init_epochs', one value row '2'.
        with open(param_csv_path, 'w') as f:
            f.write('learning_init_epochs\n2')
        # Train directly, outside run_with_tool, so that only the eval step
        # below is run under the profiling tool.
        train(dataset_path, param_csv_path, workdir, learning_init_epochs=2)
        result_path = os.path.join(workdir, 'tuning.pkz')
        command = [validate_py, 'eval', dataset_path, param_csv_path]
        command += [workdir, result_path, 'learning_init_epochs=2']
        run_with_tool(command, tool, workdir)
예제 #7
0
def test_e2e(model_type):
    """End-to-end smoke test: guess schema, load data, train, serve, query.

    ``model_type`` selects the trainer: 'single' uses train_model and
    'ensemble' uses train_ensemble; any other value raises ValueError.
    Median/mode queries are optional. A NotImplementedError from those
    calls only emits a warning.
    """
    with tempdir() as dirname:
        data_csv = os.path.join(TESTDATA, 'tiny_data.csv')
        config = TINY_CONFIG.copy()

        print('Guess schema.')
        types_csv = os.path.join(dirname, 'types.csv')
        values_csv = os.path.join(dirname, 'values.csv')
        guess_schema(data_csv, types_csv, values_csv)

        print('Load schema')
        groups_csv = os.path.join(TESTDATA, 'tiny_groups.csv')
        schema = load_schema(types_csv, values_csv, groups_csv)
        ragged_index = schema['ragged_index']
        tree_prior = schema['tree_prior']

        print('Load data')
        data = load_data(schema, data_csv)
        feature_types = [TY_MULTINOMIAL] * len(schema['feature_names'])
        table = Table(feature_types, ragged_index, data)
        dataset = {
            'schema': schema,
            'table': table,
        }

        print('Train model')
        if model_type == 'single':
            model = train_model(table, tree_prior, config)
        elif model_type == 'ensemble':
            model = train_ensemble(table, tree_prior, config)
        else:
            raise ValueError(model_type)

        print('Serve model')
        server = serve_model(dataset, model)

        print('Query model')
        evidence = {'genre': 'drama'}
        server.logprob([evidence])
        samples = server.sample(100)
        server.logprob(samples)
        samples = server.sample(100, evidence)
        server.logprob(samples)
        # Each try body holds only the optional call. A NotImplementedError
        # raised by logprob() must fail the test, not be misreported as a
        # missing median/mode.
        try:
            median = server.median([evidence])
        except NotImplementedError:
            warn('{} median not implemented'.format(model_type))
        else:
            server.logprob(median)
        try:
            mode = server.mode([evidence])
        except NotImplementedError:
            warn('{} mode not implemented'.format(model_type))
        else:
            server.logprob(mode)

        print('Examine latent structure')
        server.feature_density()
        server.observed_perplexity()
        server.latent_perplexity()
        server.latent_correlation()
        server.estimate_tree()
        server.sample_tree(10)

        print('Plotting latent structure')
        plot_circular(server)
예제 #8
0
def test_import_data():
    """import_data() should create the dataset file, which must not pre-exist."""
    with tempdir() as workdir:
        target = os.path.join(workdir, 'dataset.pkz')
        exists = os.path.exists
        assert not exists(target)
        import_data(DATA_CSV, TYPES_CSV, VALUES_CSV, GROUPS_CSV, target)
        assert exists(target)
예제 #9
0
def test_pickle(data, ext):
    """Round-trip ``data`` through pickle_dump/pickle_load via a 'test.<ext>' file."""
    with tempdir() as workdir:
        path = os.path.join(workdir, 'test.{}'.format(ext))
        pickle_dump(data, path)
        assert_equal(pickle_load(path), data)