def test_parallel_run_errors():
    """
    Run a two-variant failing experiment in parallel twice.

    With raise_exceptions=False the failures must not escape. With raise_exceptions=True the
    experiment's own exception (message 'nononono') must reach the caller.
    """
    # NOTE(review): a later definition of test_parallel_run_errors in this module replaces this
    # one at import time, so this copy never runs. Consider deleting one of them.
    with experiment_testing_context():

        @experiment_function
        def my_error_causing_test(a=1):
            raise Exception('nononono')

        my_error_causing_test.add_variant(a=2)
        experiments = my_error_causing_test.get_all_variants()
        assert len(experiments) == 2

        # Must complete without raising.
        run_multiple_experiments(experiments, parallel=True, raise_exceptions=False)

        # Must propagate the experiment's exception.
        with pytest.raises(Exception) as exc_info:
            run_multiple_experiments(experiments, parallel=True, raise_exceptions=True)
        print("^^^ Dont't worry, the above is not actually an error, we were just asserting that we caught the error.")
        assert str(exc_info.value) == 'nononono'
def test_invalid_arg_detection():
    """
    Check that we notice when an experiment is redefined with new args.

    A record stays valid while the experiment is redefined with identical defaults. It becomes
    invalid once a nested default changes.
    """
    # NOTE(review): a later definition of test_invalid_arg_detection in this module replaces this
    # one at import time, so this copy never runs.
    with experiment_testing_context(new_experiment_lib=True):

        # Mutable defaults are deliberate here: the test compares nested default values.
        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 7]}):
            return a + 1

        original_record = my_experiment_gfdsbhtds.run()
        assert original_record.args_valid()

        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 7]}):
            return a + 1

        assert original_record.args_valid()  # Assert that the args still match

        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 8]}):  # CHANGE IN ARGS!
            return a + 1

        assert not original_record.args_valid()
def test_invalid_arg_text():
    """
    Check the text that describes how an experiment's current defaults differ from the args
    saved in an earlier record (top-level change first, then a nested change as well).
    """
    # NOTE(review): a later test_invalid_arg_text in this module replaces this copy at import time.
    # The two copies expect different formats: 'Change: {a:1}->{a:2}' here versus
    # 'Change: a:1->2' there. Confirm what get_record_invalid_arg_string produces and delete the
    # stale copy.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_invalid_arg_test(a=1, b={'c': 3, 'd': 4}):
            return a + b['c'] + b['d']

        saved = my_invalid_arg_test.run()
        assert get_record_invalid_arg_string(saved, recursive=True) == '<No Change>'

        # Change the top-level default `a`.
        clear_all_experiments()

        @experiment_function
        def my_invalid_arg_test(a=2, b={'c': 3, 'd': 4}):
            return a + b['c'] + b['d']

        assert get_record_invalid_arg_string(saved, recursive=True) == 'Change: {a:1}->{a:2}'

        # Additionally change the nested default b['d'].
        clear_all_experiments()

        @experiment_function
        def my_invalid_arg_test(a=2, b={'c': 3, 'd': 2}):
            return a + b['c'] + b['d']

        assert get_record_invalid_arg_string(saved, recursive=True) == "Change: {a:1,b['d']:4}->{a:2,b['d']:2}"
def test_get_latest():
    """get_latest_record() must return the more recent of two consecutive runs."""
    # NOTE(review): a later definition of test_get_latest in this module replaces this one.
    with experiment_testing_context():
        experiment_test_function.run()  # older record; must not be reported as latest
        time.sleep(0.01)  # presumably keeps the two records' timestamps distinct
        newest = experiment_test_function.run()
        latest_id = load_experiment('experiment_test_function').get_latest_record().get_id()
        assert latest_id == newest.get_id()
def test_invalid_arg_text_when_object_arg():
    """
    Same as the invalid-arg text check, but the default argument is an object
    (MyArgumentObject) whose attribute `a` is compared.
    """
    # NOTE(review): a later test_invalid_arg_text_when_object_arg in this module replaces this
    # copy. The later copy expects 'Change: {a.a:3}->{a.a:4}', whereas this one expects
    # 'Change: a.a:3->4'. Confirm the real format.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=3)):
            return a.a+2

        saved = my_unhashable_arg_test.run()
        assert saved.get_result() == 5
        assert get_record_invalid_arg_string(saved, recursive=True) == '<No Change>'

        # Redefine with an equal (but freshly constructed) object default: still no change.
        clear_all_experiments()

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=3)):
            return a.a+2

        assert get_record_invalid_arg_string(saved, recursive=True) == '<No Change>'

        # Redefine with a different attribute value.
        clear_all_experiments()

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=4)):
            return a.a+2

        assert get_record_invalid_arg_string(saved, recursive=True) == 'Change: a.a:3->4'
def test_experiments_function_additions():
    """
    Check the log, one-line result, show and compare output of my_xxxyyy_test_experiment and its
    variants. Then print an arg table and a comparison of the three records.
    """
    with experiment_testing_context():
        # Start from a clean slate.
        for stale_record in my_xxxyyy_test_experiment.get_variant_records(flat=True):
            stale_record.delete()

        root_record = my_xxxyyy_test_experiment.run()
        a2_record = my_xxxyyy_test_experiment.get_variant('a2').run()
        # The b=17 variant is expected to fail when run.
        with pytest.raises(Exception):
            my_xxxyyy_test_experiment.get_variant(b=17).run()
        b17_record = my_xxxyyy_test_experiment.get_variant(b=17).get_latest_record()

        assert root_record.get_log() == 'xxx\n'
        assert a2_record.get_log() == 'yyy\n'

        assert get_oneline_result_string(my_xxxyyy_test_experiment.get_latest_record()) == '3bbb'
        assert get_oneline_result_string(my_xxxyyy_test_experiment.get_variant('a2').get_latest_record()) == '4bbb'
        assert get_oneline_result_string(my_xxxyyy_test_experiment.get_variant(b=17).get_latest_record()) == '<No result has been saved>'

        with CaptureStdOut() as captured:
            my_xxxyyy_test_experiment.show(my_xxxyyy_test_experiment.get_latest_record())
        assert captured.read() == '3aaa\n'

        with CaptureStdOut() as captured:
            my_xxxyyy_test_experiment.compare([root_record, a2_record])
        assert captured.read() == 'my_xxxyyy_test_experiment: 3, my_xxxyyy_test_experiment.a2: 4\n'

        all_records = [root_record, a2_record, b17_record]
        print('='*100+'\n ARGTABLE \n'+'='*100)
        print_experiment_record_argtable(all_records)
        print('='*100+'\n SHOW \n'+'='*100)
        compare_experiment_records(all_records)
def test_experiment_function_ui():
    """
    Drive the experiment browser of my_xxxyyy_test_experiment with scripted commands and check
    that each command prints the expected fragments (and at least a minimum amount of text).
    """
    experiment = my_xxxyyy_test_experiment

    def browse(command):
        # Run one command non-interactively, surfacing display errors, then close the browser.
        experiment.browse(raise_display_errors=True, command=command, close_after=True)

    with experiment_testing_context():
        for existing_record in experiment.get_variant_records(flat=True):
            existing_record.delete()
        assert len(experiment.get_variant_records(flat=True)) == 0

        browse('run all')
        assert len(experiment.get_variant_records()) == 3

        import time
        time.sleep(0.1)

        with assert_things_are_printed(min_len=1200, things=['Common Args', 'Different Args', 'Result', 'a=1, b=2', 'a=2, b=2', 'a=1, b=17']):
            browse('argtable all')

        with assert_things_are_printed(min_len=600, things=['my_xxxyyy_test_experiment: 3', 'my_xxxyyy_test_experiment.a2: 4']):
            browse('compare all -r')

        with assert_things_are_printed(min_len=1300, things=['Result', 'Logs', 'Ran Succesfully', 'Traceback']):
            browse('show all -o')

        with assert_things_are_printed(min_len=1250, things=['3aaa', '4aaa']):
            browse('show all -r')

        with assert_things_are_printed(min_len=1700, things=['Result', 'Logs', 'Ran Succesfully']):
            browse('show 0 -o')
def test_invalid_arg_text_when_object_arg():
    """
    Same as the invalid-arg text check, but the default argument is an object
    (MyArgumentObject) whose attribute `a` is compared.
    """
    # NOTE(review): this definition replaces an earlier one of the same name in this module.
    # It expects the brace format 'Change: {a.a:3}->{a.a:4}', but the test_invalid_arg_text
    # that wins (the later copy) expects the arrow format 'Change: a:1->2'. Both cannot pass
    # against one formatter. Confirm which format is current.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=3)):
            return a.a + 2

        saved = my_unhashable_arg_test.run()
        assert saved.get_result() == 5
        assert get_record_invalid_arg_string(saved, recursive=True) == '<No Change>'

        # An equal, freshly constructed object default: no change reported.
        clear_all_experiments()

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=3)):
            return a.a + 2

        assert get_record_invalid_arg_string(saved, recursive=True) == '<No Change>'

        # A different attribute value: the change must be reported.
        clear_all_experiments()

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=4)):
            return a.a + 2

        assert get_record_invalid_arg_string(saved, recursive=True) == 'Change: {a.a:3}->{a.a:4}'
def test_invalid_arg_detection():
    """
    Check that we notice when an experiment is redefined with new args.

    The record stays valid across a redefinition with identical defaults and becomes invalid
    after the nested default c['b'][1] changes from 7 to 8.
    """
    # NOTE(review): this definition replaces an earlier test_invalid_arg_detection in this module.
    with experiment_testing_context(new_experiment_lib=True):

        # Mutable defaults are intentional: they are the args being compared.
        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 7]}):
            return a+1

        first_record = my_experiment_gfdsbhtds.run()
        assert first_record.args_valid()

        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 7]}):
            return a+1

        assert first_record.args_valid()  # Assert that the args still match

        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 8]}):  # CHANGE IN ARGS!
            return a+1

        assert not first_record.args_valid()
def test_config_variant():
    """
    Check a config variant whose `smoother` argument is given as a callable (a lambda building a
    _ExponentialMovingAverageForTestingPurposes).

    Covers the output shape, reproducibility across runs, layering a further variant (a new
    seed) on top, and a smoke test of the browser.
    """
    # NOTE(review): a later test_config_variant in this module replaces this one at import time.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def demo_smooth_out_signal(smoother, seed = 1234):
            signal = np.sin(np.linspace(0, 10, 100)) + 0.1*np.random.RandomState(seed).randn(100)
            y = np.array([smoother(xt) for xt in signal])
            return y

        smooth_variant = demo_smooth_out_signal.add_config_variant(
            'exp_smooth', smoother=lambda decay=0.1: _ExponentialMovingAverageForTestingPurposes(decay))
        baseline = smooth_variant.run().get_result()
        assert baseline.shape == (100, )
        # Check that we're definitely making a new one each time. A second run must match exactly,
        # which presumably requires a fresh (stateless) smoother per run.
        assert np.array_equal(smooth_variant.run().get_result(), baseline)

        # Make sure we can still configure new args
        reseeded_variant = smooth_variant.add_variant(seed=1235)
        reseeded = reseeded_variant.run().get_result()
        assert reseeded.shape == (100, )
        assert not np.array_equal(baseline, reseeded)

        # Just check for no bugs in UI
        demo_smooth_out_signal.browse(command='q')
def test_get_latest():
    """Of two consecutive runs, get_latest_record() must return the second one."""
    # NOTE(review): this definition replaces an earlier test_get_latest in this module.
    with experiment_testing_context():
        experiment_test_function.run()  # the older record
        time.sleep(0.01)  # presumably so the two records get distinct timestamps
        second_record = experiment_test_function.run()
        reported_id = load_experiment('experiment_test_function').get_latest_record().get_id()
        assert reported_id == second_record.get_id()
def test_experiments_play_well_with_debug():
    """An experiment that sets plt.show._needmain (as pyplot does at a breakpoint) must still run."""
    # NOTE(review): a later definition with the same name in this module replaces this one.
    with experiment_testing_context():

        @experiment_function
        def my_simple_test():
            plt.show._needmain=False  # pyplot does this internally whenever a breakpoint is reached.
            return 1

        my_simple_test.run()
def test_experiments_play_well_with_debug():
    """
    Smoke test: running an experiment that sets plt.show._needmain = False (mimicking pyplot at a
    debugger breakpoint) must not break the run.
    """
    # NOTE(review): this definition replaces an earlier one of the same name in this module.
    with experiment_testing_context():

        @experiment_function
        def my_simple_test():
            plt.show._needmain = False  # pyplot does this internally whenever a breakpoint is reached.
            return 1

        my_simple_test.run()
def test_figure_saving(show_them=False):
    """
    Run experiment_test_function, close every live figure, then check that the record reloads
    exactly two figures.

    :param show_them: If True, display the reloaded figures with plt.show() at the end.
    """
    # NOTE(review): a later test_figure_saving in this module replaces this one at import time.
    with experiment_testing_context():
        record = experiment_test_function.run()
        plt.close('all')  # Close all figures, so the figures counted below come from the record
        reloaded_figures = record.load_figures()
        assert len(reloaded_figures) == 2
        if show_them:
            plt.show()
def test_get_latest_identifier():
    """
    The latest-record id of experiment_test_function is not None after a run, and reloading the
    record by that id yields a record that passes the same correctness checks.
    """
    # NOTE(review): a later definition with the same name in this module replaces this one.
    with experiment_testing_context():
        fresh_record = experiment_test_function.run()
        print(get_experiment_info('experiment_test_function'))
        assert_experiment_record_is_correct(fresh_record)

        latest_id = load_experiment('experiment_test_function').get_latest_record().get_id()
        assert latest_id is not None, 'Experiment was run, this should not be none'

        reloaded_record = load_experiment_record(latest_id)
        assert_experiment_record_is_correct(reloaded_record)
def test_figure_saving(show_them=False):
    """
    Check that the figures created by experiment_test_function are stored with its record
    (two of them), independently of the figures still open in pyplot.

    :param show_them: If True, call plt.show() after loading, for manual inspection.
    """
    # NOTE(review): this definition replaces an earlier test_figure_saving in this module.
    with experiment_testing_context():
        record = experiment_test_function.run()
        plt.close('all')  # Close all figures
        loaded = record.load_figures()
        assert len(loaded) == 2
        if show_them:
            plt.show()
def test_unpicklable_args():
    """A variant whose argument is a lambda (not picklable) can still be run and browsed."""
    # NOTE(review): a later test_unpicklable_args in this module replaces this one at import time.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def my_parametrizable_exp(f, x):
            return f(x)

        lambda_variant = my_parametrizable_exp.add_variant(f=(lambda x: 2 * x), x=3)
        lambda_variant.run()
        lambda_variant.browse(command='q')
def test_unpicklable_args():
    """Smoke test: run and browse a variant whose `f` argument is a lambda (unpicklable)."""
    # NOTE(review): this definition replaces an earlier test_unpicklable_args in this module.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def my_parametrizable_exp(f, x):
            return f(x)

        doubled = my_parametrizable_exp.add_variant(f=(lambda x: 2*x), x=3)
        doubled.run()
        doubled.browse(command='q')
def test_get_latest_identifier():
    """
    Run experiment_test_function, then look up its latest record id and reload the record from
    that id. Both the fresh and the reloaded record must pass assert_experiment_record_is_correct.
    """
    # NOTE(review): this definition replaces an earlier one of the same name in this module.
    with experiment_testing_context():
        run_record = experiment_test_function.run()
        print(get_experiment_info('experiment_test_function'))
        assert_experiment_record_is_correct(run_record)

        experiment = load_experiment('experiment_test_function')
        last_id = experiment.get_latest_record().get_id()
        assert last_id is not None, 'Experiment was run, this should not be none'

        assert_experiment_record_is_correct(load_experiment_record(last_id))
def test_run_multiple_experiments():
    """
    Run all three variants of my_api_test, once with the default settings and once with
    parallel=True. Results must come back in variant order as [3, 4, 6] both times.
    """
    # NOTE(review): a later definition with the same name in this module replaces this one.
    with experiment_testing_context():
        experiments = my_api_test.get_all_variants()
        assert len(experiments) == 3

        default_results = [rec.get_result() for rec in run_multiple_experiments(experiments)]
        assert default_results == [3, 4, 6]

        parallel_results = [rec.get_result() for rec in run_multiple_experiments(experiments, parallel=True)]
        assert parallel_results == [3, 4, 6]
def test_experiment_with():
    """
    DEPRECATED INTERFACE
    Record an experiment by calling record_experiment directly as a context manager, rather than
    going through the experiment decorators, and check the resulting record.
    """
    # NOTE(review): a later test_experiment_with in this module replaces this one at import time.
    with experiment_testing_context():
        delete_experiment_with_id('test_exp')
        with record_experiment(identifier = 'test_exp', print_to_console=True) as exp_rec:
            experiment_test_function()
        # Checked after the recording context has closed.
        assert_experiment_record_is_correct(exp_rec, show_figures=False)
def test_run_multiple_experiments():
    """
    run_multiple_experiments over my_api_test's three variants returns records whose results are
    [3, 4, 6], both with the default call and with parallel=True.
    """
    # NOTE(review): this definition replaces an earlier one of the same name in this module.
    expected_results = [3, 4, 6]
    with experiment_testing_context():
        experiments = my_api_test.get_all_variants()
        assert len(experiments)==3

        sequential = run_multiple_experiments(experiments)
        assert [r.get_result() for r in sequential] == expected_results

        parallel = run_multiple_experiments(experiments, parallel=True)
        assert [r.get_result() for r in parallel] == expected_results
def test_experiment_corrupt_detection():
    """A finished record whose info.pkl file is removed must afterwards report status CORRUPT."""
    # NOTE(review): a later definition with the same name in this module replaces this one.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_bfdsssdvs(a=1):
            return a+2

        finished_record = my_experiment_bfdsssdvs.run()
        assert finished_record.get_status() == ExpStatusOptions.FINISHED

        info_file = os.path.join(finished_record.get_dir(), 'info.pkl')
        os.remove(info_file)  # Manually remove the info file

        reloaded_record = my_experiment_bfdsssdvs.get_latest_record()
        assert reloaded_record.get_status() == ExpStatusOptions.CORRUPT
def test_experiment_api(try_browse=False):
    """
    Run the 'a2b2' variant of my_api_test and check the log, result, args and status of its
    latest record.

    :param try_browse: If True, also open the interactive browser afterwards.
    """
    # NOTE(review): a later test_experiment_api in this module replaces this one at import time.
    with experiment_testing_context():
        variant = my_api_test.get_variant('a2b2')
        variant.run()
        record = my_api_test.get_variant('a2b2').get_latest_record()
        assert record.get_log() == 'aaa\n'
        assert record.get_result() == 4
        assert record.get_args() == OrderedDict([('a', 2), ('b', 2)])
        assert record.get_status() == ExpStatusOptions.FINISHED
        if try_browse:
            my_api_test.browse()
def test_experiment_api(try_browse=False):
    """
    Exercise the record accessors (get_log, get_result, get_args, get_status) on the latest
    record of my_api_test's 'a2b2' variant.

    :param try_browse: When True, finish by opening my_api_test's browser.
    """
    # NOTE(review): this definition replaces an earlier test_experiment_api in this module.
    expected_args = OrderedDict([('a', 2), ('b', 2)])
    with experiment_testing_context():
        my_api_test.get_variant('a2b2').run()
        latest = my_api_test.get_variant('a2b2').get_latest_record()
        assert latest.get_log() == 'aaa\n'
        assert latest.get_result() == 4
        assert latest.get_args() == expected_args
        assert latest.get_status() == ExpStatusOptions.FINISHED
        if try_browse:
            my_api_test.browse()
def test_experiment_with():
    """
    DEPRECATED INTERFACE
    Use record_experiment directly as a `with` block (instead of hiding it behind the experiment
    decorators) and verify the record it produces.
    """
    # NOTE(review): this definition replaces an earlier test_experiment_with in this module.
    with experiment_testing_context():
        # Remove any leftover record with this identifier first.
        delete_experiment_with_id('test_exp')

        with record_experiment(identifier='test_exp', print_to_console=True) as exp_rec:
            experiment_test_function()

        assert_experiment_record_is_correct(exp_rec, show_figures=False)
def test_experiment_corrupt_detection():
    """
    Delete the info.pkl of a finished record and check that, when reloaded, the record's status
    is CORRUPT.
    """
    # NOTE(review): this definition replaces an earlier one of the same name in this module.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_bfdsssdvs(a=1):
            return a + 2

        record = my_experiment_bfdsssdvs.run()
        assert record.get_status() == ExpStatusOptions.FINISHED

        # Manually remove the info file
        os.remove(os.path.join(record.get_dir(), 'info.pkl'))

        assert my_experiment_bfdsssdvs.get_latest_record().get_status() == ExpStatusOptions.CORRUPT
def test_start_experiment():
    """
    DEPRECATED INTERFACE
    An alternative syntax to the with statement - less tidy but possibly better for notebooks and
    such because it avoids you having to indent all code in the experiment.

    Start a record with start_experiment, run experiment_test_function, end it with
    end_current_experiment, and check the record.
    """
    # NOTE(review): a later test_start_experiment in this module replaces this one at import time.
    with experiment_testing_context():
        # Remove any leftover 'start_stop_test' record. This is done inside the testing context, as
        # in test_experiment_with, so it acts on the same experiment store the test writes to
        # rather than whatever store is active outside it.
        delete_experiment_with_id('start_stop_test')
        record = start_experiment('start_stop_test')
        experiment_test_function()
        end_current_experiment()
        assert_experiment_record_is_correct(record, show_figures=False)
def test_start_experiment():
    """
    DEPRECATED INTERFACE
    An alternative syntax to the with statement - less tidy but possibly better for notebooks and
    such because it avoids you having to indent all code in the experiment.

    Start a record with start_experiment, run experiment_test_function, end it with
    end_current_experiment, and check the record.
    """
    # NOTE(review): this definition replaces an earlier test_start_experiment in this module.
    with experiment_testing_context():
        # Clean up any previous 'start_stop_test' record inside the testing context, consistent
        # with test_experiment_with. Previously this ran before entering the context, so it did
        # not target the store the test actually uses.
        delete_experiment_with_id('start_stop_test')
        record = start_experiment('start_stop_test')
        experiment_test_function()
        end_current_experiment()
        assert_experiment_record_is_correct(record, show_figures=False)
def test_simple_experiment_show():
    """
    compare_experiment_records on one record must print the experiment's name and the text the
    experiment printed.
    """
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_simdfsfdsgfs(a=1):
            print('xxxxx')
            print('yyyyy')
            return a+2

        record = my_simdfsfdsgfs.run()
        expected_fragments = ['my_simdfsfdsgfs', 'xxxxx\nyyyyy\n']
        with assert_things_are_printed(things=expected_fragments):
            compare_experiment_records(record)
def test_variants():
    """
    Check the different ways of creating variants: named, nested (sub-variant), unnamed, a list
    and a grid.

    Verifies their generated ids, their run results, and the total number of experiments
    reachable from the root.
    """
    # NOTE(review): a later test_variants in this module replaces this one at import time.
    @experiment_function
    def add_some_numbers(a=1, b=1):
        c = a + b
        print(c)
        return c

    with experiment_testing_context():
        # Create a named variant
        e1 = add_some_numbers.add_variant('b is 3', b=3)
        assert e1.run().get_result() == 4

        # Create a sub-variant (inherits b=3 from e1)
        e11 = e1.add_variant('a is 2', a=2)
        assert e11.run().get_result() == 5

        # Create unnamed variant
        e2 = add_some_numbers.add_variant(b=4)
        assert e2.run().get_result() == 5
        assert e2.get_id() == 'add_some_numbers.b=4'

        # Create array of variants. `range` works on both Python 2 and 3, unlike `xrange`.
        e_list = [add_some_numbers.add_variant(b=i) for i in range(5, 8)]
        assert [e.get_id() for e in e_list] == ['add_some_numbers.b=5', 'add_some_numbers.b=6', 'add_some_numbers.b=7']
        # Compare the actual results. The previous `assert [x == y for ...]` asserted a non-empty
        # list, which is always truthy, so it never checked anything.
        assert [e.run().get_result() for e in e_list] == [6, 7, 8]

        # Create grid of variants
        e_grid = [add_some_numbers.add_variant(a=a, b=b) for a, b in itertools.product([2, 3], [4, 5, 6])]
        assert [e.get_id() for e in e_grid] == [
            'add_some_numbers.a=2,b=4', 'add_some_numbers.a=2,b=5', 'add_some_numbers.a=2,b=6',
            'add_some_numbers.a=3,b=4', 'add_some_numbers.a=3,b=5', 'add_some_numbers.a=3,b=6']

        assert add_some_numbers.get_variant(a=2, b=4).run().get_result() == 6
        assert add_some_numbers.get_variant(a=3, b=5).run().get_result() == 8

        # Root + e1 + e11 + e2 + 3 list variants + 6 grid variants = 13.
        experiments = add_some_numbers.get_all_variants(include_roots=True, include_self=True)
        assert len(experiments) == 13
def test_experiment_errors():
    """
    An exception raised inside an experiment propagates out of run() with its own type, and the
    latest record is then marked ERROR.
    """
    # NOTE(review): a later test_experiment_errors in this module replaces this one at import time.
    with experiment_testing_context(new_experiment_lib=True):

        class MyManualException(Exception):
            pass

        @experiment_function
        def my_experiment_fdsgbdn():
            raise MyManualException()

        # Run twice; each run must re-raise the experiment's own exception type.
        for _attempt in range(2):
            with pytest.raises(MyManualException):
                my_experiment_fdsgbdn.run()

        latest_status = my_experiment_fdsgbdn.get_latest_record().get_status()
        assert latest_status == ExpStatusOptions.ERROR
def test_view_modes():
    """Smoke test: the browser opens and quits ('q') in both the 'full' and 'results' view modes."""
    # NOTE(review): a later test_view_modes in this module replaces this one at import time.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_simdfsffsdfsfs(a=1):
            print('xxxxx')
            print('yyyyy')
            return a+2

        for extra_a in (2, 3):
            my_simdfsffsdfsfs.add_variant(a=extra_a)
        my_simdfsffsdfsfs.run()

        for mode in ('full', 'results'):
            my_simdfsffsdfsfs.browse(view_mode=mode, command='q')
def test_generator_experiment():
    """
    For a generator experiment, the stored result is the last value yielded. This also holds when
    the generator raises part-way and run() is called with raise_exceptions=False.
    """
    # NOTE(review): a later definition with the same name in this module replaces this one.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def my_generator_exp(n_steps, poison_4 = False):
            for i in range(n_steps):
                if poison_4 and i==4:
                    raise Exception('Unlucky Number!')
                yield i

        clean = my_generator_exp.add_variant(n_steps=5)
        poisoned = my_generator_exp.add_variant(n_steps=5, poison_4=True)
        clean_record = clean.run()
        poisoned_record = poisoned.run(raise_exceptions=False)
        assert clean_record.get_result() == 4     # last of 0..4
        assert poisoned_record.get_result() == 3  # last value before the raise at i == 4
def test_get_variant_records_and_delete():
    """
    Records of my_api_test and its variants are listed by get_variant_records(flat=True) and can
    be deleted one by one.
    """
    # NOTE(review): a later definition with the same name in this module replaces this one.
    def count_records():
        return len(my_api_test.get_variant_records(flat=True))

    def delete_all_records():
        for rec in my_api_test.get_variant_records(flat=True):
            rec.delete()

    with experiment_testing_context():
        delete_all_records()
        assert count_records() == 0

        my_api_test.run()
        my_api_test.get_variant('a2b2').run()
        assert count_records() == 2

        delete_all_records()
        assert count_records() == 0
def test_view_modes():
    """
    Create an experiment with two extra variants, run the root, then open and immediately quit
    the browser in each supported view mode.
    """
    # NOTE(review): this definition replaces an earlier test_view_modes in this module.
    view_modes = ('full', 'results')
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_simdfsffsdfsfs(a=1):
            print('xxxxx')
            print('yyyyy')
            return a + 2

        my_simdfsffsdfsfs.add_variant(a=2)
        my_simdfsffsdfsfs.add_variant(a=3)
        my_simdfsffsdfsfs.run()
        for view_mode in view_modes:
            my_simdfsffsdfsfs.browse(view_mode=view_mode, command='q')
def test_get_variant_records_and_delete():
    """
    Start from zero records, run the root and the 'a2b2' variant (two records), then delete
    everything again.
    """
    # NOTE(review): this definition replaces an earlier one of the same name in this module.
    def purge():
        for existing in my_api_test.get_variant_records(flat=True):
            existing.delete()

    def n_records():
        return len(my_api_test.get_variant_records(flat=True))

    with experiment_testing_context():
        purge()
        assert n_records() == 0
        my_api_test.run()
        my_api_test.get_variant('a2b2').run()
        assert n_records() == 2
        purge()
        assert n_records() == 0
def test_experiment_errors():
    """
    Run an always-failing experiment twice. Both runs must raise its custom exception, and the
    latest record must end in ERROR status.
    """
    # NOTE(review): this definition replaces an earlier test_experiment_errors in this module.
    with experiment_testing_context(new_experiment_lib=True):

        class MyManualException(Exception):
            pass

        @experiment_function
        def my_experiment_fdsgbdn():
            raise MyManualException()

        with pytest.raises(MyManualException):
            my_experiment_fdsgbdn.run()
        # A second run must fail the same way.
        with pytest.raises(MyManualException):
            my_experiment_fdsgbdn.run()

        final_record = my_experiment_fdsgbdn.get_latest_record()
        assert final_record.get_status() == ExpStatusOptions.ERROR
def test_current_experiment_access_functions():
    """
    From inside a running experiment, check the "current experiment" accessors.

    Covers the experiment id, the record, the record id and the record directory, writing and
    reading a file through open_in_record_dir, and the record-to-experiment links.
    """
    # NOTE(review): a later definition with the same name in this module replaces this one.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_dfgsdgfdaf(a=1):
            assert a == 4, 'Only meant to run aaa variant.'
            experiment_id = get_current_experiment_id()
            rec = get_current_experiment_record()
            record_id = get_current_record_id()
            # The record id ends with '-<experiment id>'.
            assert record_id.endswith('-' + experiment_id)
            assert experiment_id == 'my_experiment_dfgsdgfdaf.aaa'
            # The record directory exists and is named after the record id.
            loc = get_current_record_dir()
            _, record_dir = os.path.split(loc)
            assert record_dir == record_id
            assert os.path.isdir(loc)
            assert loc.endswith(record_id)
            # Round-trip a pickle file through the record directory.
            with open_in_record_dir('somefile.pkl', 'wb') as f:
                pickle.dump([1, 2, 3], f, protocol=pickle.HIGHEST_PROTOCOL)
            assert os.path.exists(os.path.join(loc, 'somefile.pkl'))
            with open_in_record_dir('somefile.pkl', 'rb') as f:
                assert pickle.load(f) == [1, 2, 3]
            # The record links back to its experiment, args and directory.
            exp = rec.get_experiment()
            assert exp.get_id() == experiment_id
            assert exp.get_args() == rec.get_args() == OrderedDict([('a', 4)])
            assert rec.get_dir() == loc
            assert has_experiment_record(experiment_id)
            assert record_id in experiment_id_to_record_ids(experiment_id)
            return a + 2

        aaa_variant = my_experiment_dfgsdgfdaf.add_variant('aaa', a=4)
        aaa_variant.run()
def test_current_experiment_access_functions():
    """
    Run the 'aaa' variant (a=4) of an experiment whose body checks every "current experiment"
    accessor: ids, record, record directory, file I/O in that directory, and back-references.
    """
    # NOTE(review): this definition replaces an earlier one of the same name in this module.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_dfgsdgfdaf(a=1):
            assert a==4, 'Only meant to run aaa variant.'
            experiment_id = get_current_experiment_id()
            rec = get_current_experiment_record()
            record_id = get_current_record_id()
            assert record_id.endswith('-'+experiment_id)
            assert experiment_id == 'my_experiment_dfgsdgfdaf.aaa'
            loc = get_current_record_dir()
            _, record_dir = os.path.split(loc)
            assert record_dir == record_id  # directory name == record id
            assert os.path.isdir(loc)
            assert loc.endswith(record_id)
            with open_in_record_dir('somefile.pkl', 'wb') as f:
                pickle.dump([1, 2, 3], f, protocol=pickle.HIGHEST_PROTOCOL)
            assert os.path.exists(os.path.join(loc, 'somefile.pkl'))
            with open_in_record_dir('somefile.pkl', 'rb') as f:
                assert pickle.load(f) == [1, 2, 3]
            exp = rec.get_experiment()
            assert exp.get_id() == experiment_id
            assert exp.get_args() == rec.get_args() == OrderedDict([('a', 4)])
            assert rec.get_dir() == loc
            assert has_experiment_record(experiment_id)
            assert record_id in experiment_id_to_record_ids(experiment_id)
            return a+2

        my_experiment_dfgsdgfdaf.add_variant('aaa', a=4).run()
def test_figure_saving_and_loading():
    """
    Four save_figure_in_record() calls inside an experiment must produce exactly the files
    fig-0.pkl .. fig-3.pkl in the record directory.
    """
    from artemis.plotting.db_plotting import dbplot
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_exp():
            for t in range(4):
                dbplot(np.random.randn(20, 20, 3), 'plot')
                save_figure_in_record()

        rec = my_exp.run()  # type: ExperimentRecord
        record_dir = rec.get_dir()
        expected_locs = set(os.path.join(record_dir, 'fig-{}.pkl'.format(k)) for k in range(4))
        assert set(rec.get_figure_locs()) == expected_locs
def test_generator_experiment():
    """
    A generator experiment's result is the last value it yielded. For a run that raised part-way
    (with raise_exceptions=False), it is the last value yielded before the error.
    """
    # NOTE(review): this definition replaces an earlier test_generator_experiment in this module.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def my_generator_exp(n_steps, poison_4=False):
            for i in range(n_steps):
                if poison_4 and i == 4:
                    raise Exception('Unlucky Number!')
                yield i

        ok_variant = my_generator_exp.add_variant(n_steps=5)
        bad_variant = my_generator_exp.add_variant(n_steps=5, poison_4=True)
        ok_result = ok_variant.run().get_result()
        bad_result = bad_variant.run(raise_exceptions=False).get_result()
        assert ok_result == 4
        assert bad_result == 3
def test_config_bug_catching():
    """
    add_config_variant accepts a callable factory for an argument and must reject common
    mistakes with an AssertionError.
    """
    # NOTE(review): a later test_config_bug_catching in this module replaces this one.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def demo_smooth_out_signal_testing(smoother, seed=1234):
            signal = np.sin(np.linspace(0, 10, 100)) + 0.1 * np.random.RandomState(seed).randn(100)
            y = np.array([smoother(xt) for xt in signal])
            return y

        add_config = demo_smooth_out_signal_testing.add_config_variant

        # The right way
        add_config('exp_smooth', smoother=lambda decay=0.1: _ExponentialMovingAverageForTestingPurposes(decay))

        # Arg name already used! (the factory's `seed` clashes with the experiment's `seed`)
        with pytest.raises(AssertionError):
            add_config('exp_smooth3', smoother=lambda seed=0.1: _ExponentialMovingAverageForTestingPurposes(decay=seed))

        # Make sure we catch when we do not give a callable
        with pytest.raises(AssertionError):
            add_config('exp_smooth4', smoother=0.1)

        # Make sure we catch when we give the wrong name
        with pytest.raises(AssertionError):
            add_config('exp_smooth5', smOOOOther=lambda decay: _ExponentialMovingAverageForTestingPurposes(decay=decay))

        # Catch when we accidentally give an instance:
        with pytest.raises(AssertionError):
            add_config('exp_smooth6', smoother=_ExponentialMovingAverageForTestingPurposes(decay=0.1))
def demo_browse_record_figs():
    """
    Interactive demo (its name lacks the `test_` prefix, so pytest's default collection skips it).

    Saves four spiral line plots into a record and opens them with browse_record_figs.
    """
    from artemis.plotting.db_plotting import dbplot
    from matplotlib import pyplot as plt
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_exp():
            for t in range(4):
                pts = np.linspace(0, 3*(t+1), 400)
                dbplot((pts*np.cos(pts), pts*np.sin(pts)), 'plot', title='t={}'.format(t), plot_type='line')
                save_figure_in_record()
                plt.close(plt.gcf())

        record = my_exp.run()  # type: ExperimentRecord
        browse_record_figs(record)
def test_parallel_run_errors():
    """
    Parallel runs of an always-failing experiment (root plus one variant) must swallow the error
    when raise_exceptions=False and re-raise it (message 'nononono') when raise_exceptions=True.
    """
    # NOTE(review): this definition replaces an earlier test_parallel_run_errors in this module.
    with experiment_testing_context():

        @experiment_function
        def my_error_causing_test(a=1):
            raise Exception('nononono')

        my_error_causing_test.add_variant(a=2)
        all_variants = my_error_causing_test.get_all_variants()
        assert len(all_variants)==2

        run_multiple_experiments(all_variants, parallel=True, raise_exceptions=False)
        with pytest.raises(Exception) as raised:
            run_multiple_experiments(all_variants, parallel=True, raise_exceptions=True)
        print("^^^ Dont't worry, the above is not actually an error, we were just asserting that we caught the error.")
        assert str(raised.value) == 'nononono'
def test_invalid_arg_detection_2():
    """
    Redefining an experiment with the same 100-element set as its default must leave an earlier
    record's args valid (args_valid() returns exactly True).
    """
    # A set is used presumably to check that arg comparison does not depend on set iteration
    # order. NOTE(review): a later definition with the same name in this module replaces this one.
    with experiment_testing_context(new_experiment_lib=True):
        shared_default = {'a' + str(i) for i in range(100)}

        @experiment_function
        def my_experiment_gfdsbhtds(a=shared_default):
            return None

        record = my_experiment_gfdsbhtds.run()
        assert record.args_valid() is True

        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=shared_default):
            return None

        assert record.args_valid() is True  # Assert that the args still match
def test_accessing_experiment_dir():
    """
    An experiment writes a file into its record directory. Afterwards that file, and the
    experiment's printed output (output.txt), must be readable through the record object.
    """
    # NOTE(review): a later definition with the same name in this module replaces this one.
    with experiment_testing_context():

        @experiment_function
        def access_dir_test():
            print('123')
            print('abc')
            # Named record_dir rather than `dir` to avoid shadowing the builtin.
            record_dir = get_current_record_dir()
            with open_in_record_dir('my_test_file.txt', 'w') as f:
                f.write('Experiment Directory is: {}'.format(record_dir))

        record = access_dir_test.run()
        assert 'my_test_file.txt' in record.list_files()
        with record.open_file('my_test_file.txt') as f:
            assert f.read() == 'Experiment Directory is: {}'.format(record.get_dir())
        with record.open_file('output.txt') as f:
            assert f.read() == '123\nabc\n'
def test_invalid_arg_detection_2():
    """
    Run an experiment whose default is a large set, redefine it with the same default, and check
    that args_valid() still returns exactly True for the original record.
    """
    # NOTE(review): this definition replaces an earlier test_invalid_arg_detection_2 in this module.
    with experiment_testing_context(new_experiment_lib=True):
        big_set = set("a%s" % i for i in range(100))

        @experiment_function
        def my_experiment_gfdsbhtds(a=big_set):
            return None

        rec = my_experiment_gfdsbhtds.run()
        assert rec.args_valid() is True

        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=big_set):
            return None

        assert rec.args_valid() is True  # Assert that the args still match
def test_config_variant():
    """
    Check a config variant whose `smoother` is supplied as a callable (a lambda building a
    _ExponentialMovingAverageForTestingPurposes).

    The result has shape (100,) and is identical across runs. A variant with a different seed
    layered on top gives a different result. The browser opens and quits cleanly.
    """
    # NOTE(review): this definition replaces an earlier test_config_variant in this module.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def demo_smooth_out_signal(smoother, seed=1234):
            signal = np.sin(np.linspace(0, 10, 100)) + 0.1 * np.random.RandomState(seed).randn(100)
            y = np.array([smoother(xt) for xt in signal])
            return y

        X = demo_smooth_out_signal.add_config_variant(
            'exp_smooth',
            smoother=lambda decay=0.1: _ExponentialMovingAverageForTestingPurposes(decay))
        answer = X.run().get_result()
        assert answer.shape == (100, )
        rerun_answer = X.run().get_result()
        assert np.array_equal(rerun_answer, answer)  # Check that we're definitely making a new one each time

        # Make sure we can still configure new args
        X2 = X.add_variant(seed=1235)
        answer2 = X2.run().get_result()
        assert answer2.shape == (100, )
        assert not np.array_equal(answer, answer2)

        # Just check for no bugs in UI
        demo_smooth_out_signal.browse(command='q')
def test_parameter_search():
    """
    Run a parameter search (n_calls=5) over x, y in [-5, 5] on a quadratic bowl.

    The result must list the parameter names and the last evaluated value must be lower than the
    first.
    """
    from skopt.space import Real
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def bowl(x, y):
            return {'z': (x - 2)**2 + (y + 3)**2}

        search_space = {
            'x': Real(-5, 5, 'uniform'),
            'y': Real(-5, 5, 'uniform'),
        }
        ex_search = bowl.add_parameter_search(
            space=search_space,
            scalar_func=lambda result: result['z'],
            search_params=dict(n_calls=5),
        )
        search_result = ex_search.run().get_result()
        assert search_result['names'] == ['x', 'y']
        values = search_result['func_vals']
        assert values[-1] < values[0]
def test_variants():
    """
    Check named, nested, unnamed, list and grid variants.

    Verifies their generated ids, their run results, and the total number of experiments
    reachable from the root (13).
    """
    # NOTE(review): this definition replaces an earlier test_variants in this module.
    @experiment_function
    def add_some_numbers(a=1, b=1):
        c=a+b
        print(c)
        return c

    with experiment_testing_context():
        # Create a named variant
        e1 = add_some_numbers.add_variant('b is 3', b=3)
        assert e1.run().get_result() == 4

        # Create a sub-variant (inherits b=3 from e1)
        e11 = e1.add_variant('a is 2', a=2)
        assert e11.run().get_result() == 5

        # Create unnamed variant
        e2 = add_some_numbers.add_variant(b=4)
        assert e2.run().get_result() == 5
        assert e2.get_id() == 'add_some_numbers.b=4'

        # Create array of variants (`range` instead of the Python-2-only `xrange`).
        e_list = [add_some_numbers.add_variant(b=i) for i in range(5, 8)]
        assert [e.get_id() for e in e_list] == ['add_some_numbers.b=5', 'add_some_numbers.b=6', 'add_some_numbers.b=7']
        # Compare the actual results. The old `assert [r == j for ...]` asserted a non-empty list,
        # which is always truthy, so the results were never checked.
        assert [e.run().get_result() for e in e_list] == [6, 7, 8]

        # Create grid of variants
        e_grid = [add_some_numbers.add_variant(a=a, b=b) for a, b in itertools.product([2, 3], [4, 5, 6])]
        assert [e.get_id() for e in e_grid] == [
            'add_some_numbers.a=2,b=4', 'add_some_numbers.a=2,b=5', 'add_some_numbers.a=2,b=6',
            'add_some_numbers.a=3,b=4', 'add_some_numbers.a=3,b=5', 'add_some_numbers.a=3,b=6']

        assert add_some_numbers.get_variant(a=2, b=4).run().get_result() == 6
        assert add_some_numbers.get_variant(a=3, b=5).run().get_result() == 8

        # Root + e1 + e11 + e2 + 3 list variants + 6 grid variants = 13.
        experiments = add_some_numbers.get_all_variants(include_roots=True, include_self=True)
        assert len(experiments) == 13
def test_invalid_arg_text():
    """
    Check the text describing how the current defaults of a redefined experiment differ from
    the args saved in an earlier record.
    """
    # NOTE(review): this definition replaces an earlier test_invalid_arg_text in this module,
    # which expected 'Change: {a:1}->{a:2}'. Also, the test_invalid_arg_text_when_object_arg
    # that wins expects the brace format 'Change: {a.a:3}->{a.a:4}', while this test expects the
    # arrow format. One of them must be stale; confirm against get_record_invalid_arg_string.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_invalid_arg_test(a=1, b={'c': 3, 'd': 4}):
            return a+b['c']+b['d']

        original = my_invalid_arg_test.run()
        assert get_record_invalid_arg_string(original, recursive=True) == '<No Change>'

        # Top-level default change.
        clear_all_experiments()

        @experiment_function
        def my_invalid_arg_test(a=2, b={'c': 3, 'd': 4}):
            return a+b['c']+b['d']

        assert get_record_invalid_arg_string(original, recursive=True) == 'Change: a:1->2'

        # Top-level plus nested default change.
        clear_all_experiments()

        @experiment_function
        def my_invalid_arg_test(a=2, b={'c': 3, 'd': 2}):
            return a+b['c']+b['d']

        assert get_record_invalid_arg_string(original, recursive=True) == "Change: a:1->2, b['d']:4->2"
def test_config_bug_catching():
    """
    Misuses of add_config_variant must raise AssertionError: a factory parameter that reuses an
    experiment arg name, a non-callable value, an unknown argument name, or an instance instead
    of a factory.
    """
    # NOTE(review): this definition replaces an earlier test_config_bug_catching in this module.
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def demo_smooth_out_signal_testing(smoother, seed = 1234):
            signal = np.sin(np.linspace(0, 10, 100)) + 0.1*np.random.RandomState(seed).randn(100)
            y = np.array([smoother(xt) for xt in signal])
            return y

        # The right way
        demo_smooth_out_signal_testing.add_config_variant(
            'exp_smooth', smoother=lambda decay=0.1: _ExponentialMovingAverageForTestingPurposes(decay))

        with pytest.raises(AssertionError):  # Arg name already used!
            demo_smooth_out_signal_testing.add_config_variant(
                'exp_smooth3', smoother=lambda seed=0.1: _ExponentialMovingAverageForTestingPurposes(decay=seed))

        with pytest.raises(AssertionError):  # Make sure we catch when we do not give a callable
            demo_smooth_out_signal_testing.add_config_variant('exp_smooth4', smoother=0.1)

        with pytest.raises(AssertionError):  # Make sure we catch when we give the wrong name
            demo_smooth_out_signal_testing.add_config_variant(
                'exp_smooth5', smOOOOther=lambda decay: _ExponentialMovingAverageForTestingPurposes(decay=decay))

        with pytest.raises(AssertionError):  # Catch when we accidentally give an instance:
            demo_smooth_out_signal_testing.add_config_variant(
                'exp_smooth6', smoother=_ExponentialMovingAverageForTestingPurposes(decay=0.1))
def test_accessing_experiment_dir():
    """
    Write a file into the current record directory from inside an experiment. Then read that file,
    and the captured console output in output.txt, back through the ExperimentRecord API.
    """
    # NOTE(review): this definition replaces an earlier test_accessing_experiment_dir in this module.
    with experiment_testing_context():

        @experiment_function
        def access_dir_test():
            print('123')
            print('abc')
            current_dir = get_current_record_dir()  # avoids shadowing the builtin `dir`
            with open_in_record_dir('my_test_file.txt', 'w') as f:
                f.write('Experiment Directory is: {}'.format(current_dir))

        record = access_dir_test.run()
        stored_files = record.list_files()
        assert 'my_test_file.txt' in stored_files

        expected_text = 'Experiment Directory is: {}'.format(record.get_dir())
        with record.open_file('my_test_file.txt') as f:
            assert f.read() == expected_text
        with record.open_file('output.txt') as f:
            assert f.read() == '123\nabc\n'
def test_saving_result():
    """Running the root add_some_numbers_test_experiment must store the result 2 in its record."""
    with experiment_testing_context():
        # Run root experiment
        result = add_some_numbers_test_experiment.run().get_result()
        assert result == 2