def test_multiple_algorithms_engine():
    """Check that predictions from several algorithms are keyed by name.

    Two ``TestMultiAlgorithm`` entries are registered under ``algo1`` and
    ``algo2`` with different ``p`` params. With ``TestIdentityServing`` the
    engine's prediction is expected to map each algorithm name to its ``p``.
    """
    algorithm_specs = {
        'algo1': {'class': TestMultiAlgorithm, 'params': {'p': 'A'}},
        'algo2': {'class': TestMultiAlgorithm, 'params': {'p': 'B'}},
    }
    engine = Engine({
        'datasource': TestEmptyDataSource,
        'algorithms': algorithm_specs,
        'serving': TestIdentityServing,
    })

    # Each model is a fresh, opaque placeholder object; the assertion below
    # only compares the per-algorithm outputs.
    dummy_models = {name: object() for name in ('algo1', 'algo2')}

    expected = {'algo1': 'A', 'algo2': 'B'}
    nt.assert_equals(engine.predict(dummy_models, None), expected)
def test_setup():
    """Generate the test data and run one training pass.

    The engine chains ``TestDataSource`` -> ``TestPreparator`` ->
    ``TestAlgorithm``. The algorithm is given a custom ``model.csv`` path.
    """
    make_test_data()

    datasource = {'class': TestDataSource, 'params': {'csv': test_data_file}}
    algorithm = {
        'class': TestAlgorithm,
        'params': {'model.csv': '~/.tidml/tests/model2.csv'},  # custom
    }
    engine = Engine({
        'datasource': datasource,
        'preparator': TestPreparator,
        'algorithm': algorithm,
    })
    engine.train()
def main():
    """Train the house_prices example from its YAML config, then predict once."""
    engine = Engine({'config': 'examples/dase/house_prices/config.yaml'})
    engine.train()

    trained_models = engine.load_models()
    # Presumably one query's feature vector (9 values) — confirm against the
    # example's config. The prediction result is not used here.
    sample_query = [0.07, 0.99, 0.0, 0.51, 0.69, 0.77, 0.77, 0.75, 0.44]
    engine.predict(trained_models, sample_query)
def create():
    """Build the hello_world engine.

    Pairs ``MyDataSource`` with ``MyAlgorithm``. The algorithm's
    ``model.pickle`` param is set to ``~/.tidml/hello_world/model.pkl``.
    """
    algorithm_config = {
        "class": MyAlgorithm,
        "params": {"model.pickle": "~/.tidml/hello_world/model.pkl"},
    }
    return Engine({"datasource": MyDataSource, "algorithm": algorithm_config})
def test_insane_datasource():
    """Training must fail when the datasource is flagged ``insane``.

    Expects ``engine.train`` to raise ``RuntimeError`` whose message matches
    ``'training_data insane!'``.
    """
    datasource_params = {'csv': test_data_file, 'insane': True}
    engine = Engine({
        'datasource': {'class': TestDataSource, 'params': datasource_params},
    })
    nt.assert_raises_regexp(
        RuntimeError, 'training_data insane!', engine.train,
    )
def test_insane_algorithm():
    """Training must fail when the algorithm is flagged ``insane``.

    Builds the same datasource -> preparator -> algorithm pipeline as the
    setup test and expects ``engine.train`` to raise ``RuntimeError`` whose
    message matches ``'model insane!'``.

    NOTE(review): this test does not call ``make_test_data()`` itself. It
    appears to rely on ``test_data_file`` having been created already —
    confirm test ordering / fixtures.
    """
    algorithm_params = {
        'model.csv': '~/.tidml/tests/model2.csv',  # custom
        'insane': True,
    }
    engine = Engine({
        'datasource': {'class': TestDataSource, 'params': {'csv': test_data_file}},
        'preparator': TestPreparator,
        'algorithm': {'class': TestAlgorithm, 'params': algorithm_params},
    })
    nt.assert_raises_regexp(RuntimeError, 'model insane!', engine.train)
def test_simple_engine():
    """Train, reload and query an engine with no preparator.

    Using ``TestSimpleAlgorithm`` on the test CSV, a prediction for ``3`` is
    expected to be ``6``.
    """
    datasource_cfg = {
        'class': TestDataSource,
        'params': {'csv': test_data_file},
    }
    algorithm_cfg = {
        'class': TestSimpleAlgorithm,
        'params': {'model.pickle': '~/.tidml/tests/model.pkl'},  # default built-in
    }
    engine = Engine({'datasource': datasource_cfg, 'algorithm': algorithm_cfg})

    engine.train()
    loaded = engine.load_models()
    nt.assert_equals(engine.predict(loaded, 3), 6)
def create():
    """Return an engine configured from the gnb_classifier example's YAML file."""
    config_path = 'examples/dase/gnb_classifier/config.yaml'
    return Engine({'config': config_path})
def create():
    """Return an engine configured from the regression example's YAML file."""
    config_path = 'examples/dase/regression/config.yaml'
    return Engine({'config': config_path})