Beispiel #1
0
def test_parallel_run_errors():
    """
    Check error handling when experiments are run in parallel: errors are
    swallowed with raise_exceptions=False and re-raised, with their original
    message, with raise_exceptions=True.
    """

    with experiment_testing_context():

        @experiment_function
        def my_error_causing_test(a=1):
            raise Exception('nononono')

        my_error_causing_test.add_variant(a=2)

        variants = my_error_causing_test.get_all_variants()

        # The root experiment plus the a=2 variant.
        assert len(variants) == 2

        # Must not raise even though every experiment fails.
        run_multiple_experiments(variants,
                                 parallel=True,
                                 raise_exceptions=False)

        with pytest.raises(Exception) as err:
            run_multiple_experiments(variants,
                                     parallel=True,
                                     raise_exceptions=True)
        print(
            "^^^ Don't worry, the above is not actually an error, we were just asserting that we caught the error."
        )

        assert str(err.value) == 'nononono'
Beispiel #2
0
def test_invalid_arg_detection():
    """
    A stored record should report valid args only while the experiment's
    current definition has the same default arguments as when it was recorded.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 7]}):
            return a + 1

        record = my_experiment_gfdsbhtds.run()
        assert record.args_valid()

        # Redefine with identical defaults -> still valid.
        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 7]}):
            return a + 1

        assert record.args_valid()

        # Redefine with c['b'][1] changed from 7 to 8 -> no longer valid.
        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 8]}):
            return a + 1

        assert not record.args_valid()
Beispiel #3
0
def test_invalid_arg_text():
    """
    Check the text describing how an experiment's current default args differ
    from the args stored in an existing record.
    """

    with experiment_testing_context(new_experiment_lib=True):

        def describe(rec):
            # Recursive diff text between the record's args and current defaults.
            return get_record_invalid_arg_string(rec, recursive=True)

        @experiment_function
        def my_invalid_arg_test(a=1, b={'c': 3, 'd': 4}):
            return a + b['c'] + b['d']

        record = my_invalid_arg_test.run()
        assert describe(record) == '<No Change>'
        clear_all_experiments()

        # Change a top-level arg.
        @experiment_function
        def my_invalid_arg_test(a=2, b={'c': 3, 'd': 4}):
            return a + b['c'] + b['d']

        assert describe(record) == 'Change: {a:1}->{a:2}'
        clear_all_experiments()

        # Also change a value nested inside a dict arg.
        @experiment_function
        def my_invalid_arg_test(a=2, b={'c': 3, 'd': 2}):
            return a + b['c'] + b['d']

        assert describe(record) == "Change: {a:1,b['d']:4}->{a:2,b['d']:2}"
def test_get_latest():
    """
    The latest record of an experiment should be the one from its most recent run.
    """
    with experiment_testing_context():
        experiment_test_function.run()
        time.sleep(0.01)  # presumably so the two runs get distinct timestamps
        newest = experiment_test_function.run()
        experiment = load_experiment('experiment_test_function')
        assert experiment.get_latest_record().get_id() == newest.get_id()
def test_invalid_arg_text_when_object_arg():
    """
    Arg-change text should also work when a default arg is an object
    (presumably unhashable, judging by the experiment name) rather than a plain value.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=3)):
            return a.a+2

        record = my_unhashable_arg_test.run()
        assert record.get_result() == 5

        def arg_change_text():
            return get_record_invalid_arg_string(record, recursive=True)

        assert arg_change_text() == '<No Change>'

        # Identical redefinition -> no change reported.
        clear_all_experiments()

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=3)):
            return a.a+2

        assert arg_change_text() == '<No Change>'

        # Attribute of the object arg changes from 3 to 4.
        clear_all_experiments()

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=4)):
            return a.a+2

        assert arg_change_text() == 'Change: a.a:3->4'
def test_experiments_function_additions():
    """
    Check the extra methods carried by an experiment function: running
    variants, reading logs and one-line results, and the show/compare hooks.
    """

    exp = my_xxxyyy_test_experiment

    with experiment_testing_context():

        # Remove records left over from earlier runs.
        for leftover in exp.get_variant_records(flat=True):
            leftover.delete()

        r1 = exp.run()
        r2 = exp.get_variant('a2').run()
        # The b=17 variant is expected to raise.
        with pytest.raises(Exception):
            exp.get_variant(b=17).run()
        r3 = exp.get_variant(b=17).get_latest_record()

        assert r1.get_log() == 'xxx\n'
        assert r2.get_log() == 'yyy\n'

        assert get_oneline_result_string(exp.get_latest_record()) == '3bbb'
        assert get_oneline_result_string(exp.get_variant('a2').get_latest_record()) == '4bbb'
        assert get_oneline_result_string(exp.get_variant(b=17).get_latest_record()) == '<No result has been saved>'

        with CaptureStdOut() as captured:
            exp.show(exp.get_latest_record())
        assert captured.read() == '3aaa\n'

        with CaptureStdOut() as captured:
            exp.compare([r1, r2])
        assert captured.read() == 'my_xxxyyy_test_experiment: 3, my_xxxyyy_test_experiment.a2: 4\n'

        rule = '=' * 100
        print(rule + '\n ARGTABLE \n' + rule)
        print_experiment_record_argtable([r1, r2, r3])

        print(rule + '\n SHOW \n' + rule)
        compare_experiment_records([r1, r2, r3])
def test_experiment_function_ui():
    """
    Drive the experiment browser through a series of commands and check that
    each prints enough output containing the expected strings.
    """
    import time

    exp = my_xxxyyy_test_experiment

    with experiment_testing_context():
        for stale in exp.get_variant_records(flat=True):
            stale.delete()

        assert len(exp.get_variant_records(flat=True)) == 0

        exp.browse(raise_display_errors=True, command='run all', close_after=True)
        assert len(exp.get_variant_records()) == 3

        time.sleep(0.1)

        # (browser command, minimum printed length, strings that must appear)
        checks = [
            ('argtable all', 1200, ['Common Args', 'Different Args', 'Result', 'a=1, b=2', 'a=2, b=2', 'a=1, b=17']),
            ('compare all -r', 600, ['my_xxxyyy_test_experiment: 3', 'my_xxxyyy_test_experiment.a2: 4']),
            ('show all -o', 1300, ['Result', 'Logs', 'Ran Succesfully', 'Traceback']),
            ('show all -r', 1250, ['3aaa', '4aaa']),
            ('show 0 -o', 1700, ['Result', 'Logs', 'Ran Succesfully']),
        ]
        for command, min_len, things in checks:
            with assert_things_are_printed(min_len=min_len, things=things):
                exp.browse(raise_display_errors=True, command=command, close_after=True)
Beispiel #8
0
def test_invalid_arg_text_when_object_arg():
    """
    Arg-change text should also be produced when a default argument is an
    object (presumably unhashable, judging by the experiment name).
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=3)):
            return a.a + 2

        record = my_unhashable_arg_test.run()
        assert record.get_result() == 5

        def change_text():
            return get_record_invalid_arg_string(record, recursive=True)

        assert change_text() == '<No Change>'

        # Redefine identically -> nothing changed.
        clear_all_experiments()

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=3)):
            return a.a + 2

        assert change_text() == '<No Change>'

        # Redefine with the object's attribute changed from 3 to 4.
        clear_all_experiments()

        @experiment_function
        def my_unhashable_arg_test(a=MyArgumentObject(a=4)):
            return a.a + 2

        assert change_text() == 'Change: {a.a:3}->{a.a:4}'
def test_invalid_arg_detection():
    """
    A record's args stay valid across redefinitions of its experiment only
    while the default arguments are unchanged.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 7]}):
            return a+1

        original_record = my_experiment_gfdsbhtds.run()
        assert original_record.args_valid()

        # Same defaults -> the args still match.
        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 7]}):
            return a+1

        assert original_record.args_valid()

        # c['b'][1] goes from 7 to 8 -> mismatch must be detected.
        clear_all_experiments()

        @experiment_function
        def my_experiment_gfdsbhtds(a=1, b=[2, 3.], c={'a': 5, 'b': [6, 8]}):
            return a+1

        assert not original_record.args_valid()
Beispiel #10
0
def test_config_variant():
    """
    A config variant builds an argument from a factory. Check three things:
    - the factory's product is used;
    - repeated runs give identical results;
    - ordinary args can still be overridden by further variants.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def demo_smooth_out_signal(smoother, seed=1234):
            signal = np.sin(np.linspace(0, 10, 100)) + 0.1 * np.random.RandomState(seed).randn(100)
            return np.array([smoother(xt) for xt in signal])

        smoothing_variant = demo_smooth_out_signal.add_config_variant(
            'exp_smooth',
            smoother=lambda decay=0.1: _ExponentialMovingAverageForTestingPurposes(decay))

        first = smoothing_variant.run().get_result()
        assert first.shape == (100, )
        # The smoother is presumably stateful, so matching results on a second
        # run indicate a fresh smoother is built for each run.
        assert np.array_equal(smoothing_variant.run().get_result(), first)

        # Ordinary args remain configurable on top of the config variant.
        reseeded_variant = smoothing_variant.add_variant(seed=1235)
        reseeded = reseeded_variant.run().get_result()
        assert reseeded.shape == (100, )
        assert not np.array_equal(first, reseeded)

        # Smoke-test the UI: open and immediately quit.
        demo_smooth_out_signal.browse(command='q')
Beispiel #11
0
def test_get_latest():
    """
    After two runs, the experiment's latest record must be the second one.
    """
    with experiment_testing_context():
        experiment_test_function.run()
        time.sleep(0.01)  # presumably keeps the two runs' timestamps distinct
        second_record = experiment_test_function.run()
        latest = load_experiment('experiment_test_function').get_latest_record()
        assert latest.get_id() == second_record.get_id()
def test_experiments_play_well_with_debug():
    """
    An experiment should still run after pyplot's ``show`` has been modified
    the way it is when a debugger breakpoint is hit.
    """

    with experiment_testing_context():

        @experiment_function
        def my_simple_test():
            # pyplot does this internally whenever a breakpoint is reached.
            plt.show._needmain=False
            return 1

        my_simple_test.run()
Beispiel #13
0
def test_experiments_play_well_with_debug():
    """
    Running an experiment must not break when pyplot's ``show`` carries the
    flag a debugger breakpoint sets.
    """

    with experiment_testing_context():

        @experiment_function
        def my_simple_test():
            # Mimic what pyplot does internally whenever a breakpoint is reached.
            plt.show._needmain = False
            return 1

        my_simple_test.run()
def test_figure_saving(show_them=False):
    """
    Figures produced during an experiment run should be loadable from its record.

    :param show_them: If True, display the loaded figures at the end.
    """

    with experiment_testing_context():
        record = experiment_test_function.run()

    plt.close('all')  # Close every open figure before reloading
    loaded_figures = record.load_figures()
    assert len(loaded_figures) == 2
    if not show_them:
        return
    plt.show()
def test_get_latest_identifier():
    """
    After a run, the experiment's latest record id should be set and should
    load back into a correct record.
    """

    with experiment_testing_context():
        run_record = experiment_test_function.run()
        print(get_experiment_info('experiment_test_function'))
        assert_experiment_record_is_correct(run_record)

        experiment = load_experiment('experiment_test_function')
        latest_id = experiment.get_latest_record().get_id()
        assert latest_id is not None, 'Experiment was run, this should not be none'

        reloaded_record = load_experiment_record(latest_id)
        assert_experiment_record_is_correct(reloaded_record)
Beispiel #16
0
def test_figure_saving(show_them=False):
    """
    The two figures created by ``experiment_test_function`` should be saved
    with its record and reloadable.

    :param show_them: When True, show the reloaded figures.
    """

    with experiment_testing_context():
        saved_record = experiment_test_function.run()

    # Close all figures so that only reloaded ones will be shown.
    plt.close('all')
    figures = saved_record.load_figures()
    assert len(figures) == 2
    if show_them:
        plt.show()
Beispiel #17
0
def test_unpicklable_args():
    """
    A variant whose args include a lambda (which the standard ``pickle``
    module cannot serialize) should still run and be browsable.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def my_parametrizable_exp(f, x):
            return f(x)

        doubling = my_parametrizable_exp.add_variant(f=(lambda x: 2 * x), x=3)
        doubling.run()
        doubling.browse(command='q')  # open the browser and quit right away
Beispiel #18
0
def test_unpicklable_args():
    """
    Experiments must cope with argument values (here a lambda) that the
    standard ``pickle`` module cannot serialize.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def my_parametrizable_exp(f, x):
            return f(x)

        lambda_variant = my_parametrizable_exp.add_variant(f=(lambda x: 2*x), x=3)
        lambda_variant.run()
        # Smoke-test the UI with the unpicklable arg: open, then quit.
        lambda_variant.browse(command='q')
Beispiel #19
0
def test_get_latest_identifier():
    """
    The id of the most recent record must be retrievable and reload as a
    correct record.
    """

    with experiment_testing_context():
        fresh_record = experiment_test_function.run()
        print(get_experiment_info('experiment_test_function'))
        assert_experiment_record_is_correct(fresh_record)

        newest_id = load_experiment('experiment_test_function').get_latest_record().get_id()
        assert newest_id is not None, 'Experiment was run, this should not be none'

        assert_experiment_record_is_correct(load_experiment_record(newest_id))
Beispiel #20
0
def test_run_multiple_experiments():
    """
    Running all variants of ``my_api_test`` together gives the same results
    with default settings and with ``parallel=True``.
    """

    with experiment_testing_context():

        experiments = my_api_test.get_all_variants()
        assert len(experiments) == 3

        # First with default settings, then explicitly in parallel.
        for run_kwargs in ({}, {'parallel': True}):
            records = run_multiple_experiments(experiments, **run_kwargs)
            assert [rec.get_result() for rec in records] == [3, 4, 6]
def test_experiment_with():
    """
    DEPRECATED INTERFACE

    Uses ``record_experiment`` directly as a context manager rather than
    hiding it behind an experiment function.
    """

    with experiment_testing_context():
        delete_experiment_with_id('test_exp')
        recorder = record_experiment(identifier='test_exp', print_to_console=True)
        with recorder as exp_rec:
            experiment_test_function()
            assert_experiment_record_is_correct(exp_rec, show_figures=False)
def test_run_multiple_experiments():
    """
    ``run_multiple_experiments`` should return records with the same results
    whether called with its default settings or with ``parallel=True``.
    """

    with experiment_testing_context():

        variants = my_api_test.get_all_variants()
        assert len(variants)==3

        expected_results = [3, 4, 6]
        for extra_kwargs in ({}, dict(parallel=True)):
            results = [r.get_result() for r in run_multiple_experiments(variants, **extra_kwargs)]
            assert results == expected_results
def test_experiment_corrupt_detection():
    """
    A record whose 'info.pkl' file has been deleted should be reported as CORRUPT.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_bfdsssdvs(a=1):
            return a+2

        finished = my_experiment_bfdsssdvs.run()
        assert finished.get_status() == ExpStatusOptions.FINISHED

        # Remove the info file by hand, then look the record up again.
        info_file = os.path.join(finished.get_dir(), 'info.pkl')
        os.remove(info_file)
        damaged = my_experiment_bfdsssdvs.get_latest_record()
        assert damaged.get_status() == ExpStatusOptions.CORRUPT
def test_experiment_api(try_browse=False):
    """
    Run the 'a2b2' variant of ``my_api_test`` and check the log, result, args
    and status of its latest record.

    :param try_browse: If True, open the browser on ``my_api_test`` afterwards
        (outside the testing context).
    """

    with experiment_testing_context():
        my_api_test.get_variant('a2b2').run()
        latest = my_api_test.get_variant('a2b2').get_latest_record()

        assert latest.get_log() == 'aaa\n'
        assert latest.get_result() == 4
        assert latest.get_args() == OrderedDict([('a', 2), ('b', 2)])
        assert latest.get_status() == ExpStatusOptions.FINISHED

    if not try_browse:
        return
    my_api_test.browse()
Beispiel #25
0
def test_experiment_api(try_browse=False):
    """
    Check the public record API (log, result, args, status) on a freshly run
    variant.

    :param try_browse: When True, open the interactive browser afterwards.
    """

    with experiment_testing_context():
        my_api_test.get_variant('a2b2').run()
        rec = my_api_test.get_variant('a2b2').get_latest_record()

        expected_args = OrderedDict([('a', 2), ('b', 2)])
        assert rec.get_log() == 'aaa\n'
        assert rec.get_result() == 4
        assert rec.get_args() == expected_args
        assert rec.get_status() == ExpStatusOptions.FINISHED

    if try_browse:
        my_api_test.browse()
Beispiel #26
0
def test_experiment_with():
    """
    DEPRECATED INTERFACE

    Records an experiment by calling ``record_experiment`` explicitly in a
    with-statement.
    """

    with experiment_testing_context():
        delete_experiment_with_id('test_exp')
        with record_experiment(identifier='test_exp', print_to_console=True) as recorded:
            experiment_test_function()
            assert_experiment_record_is_correct(recorded, show_figures=False)
Beispiel #27
0
def test_experiment_corrupt_detection():
    """
    Deleting a record's 'info.pkl' should make that record report a CORRUPT status.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_bfdsssdvs(a=1):
            return a + 2

        rec = my_experiment_bfdsssdvs.run()
        assert rec.get_status() == ExpStatusOptions.FINISHED

        record_dir = rec.get_dir()
        os.remove(os.path.join(record_dir, 'info.pkl'))  # simulate corruption

        assert my_experiment_bfdsssdvs.get_latest_record().get_status() == ExpStatusOptions.CORRUPT
def test_start_experiment():
    """
    DEPRECATED INTERFACE

    Uses explicit ``start_experiment`` / ``end_current_experiment`` calls
    instead of a with-statement. This is less tidy, but avoids indenting the
    experiment body, which can be handy in notebooks.
    """

    delete_experiment_with_id('start_stop_test')
    with experiment_testing_context():
        started_record = start_experiment('start_stop_test')
        experiment_test_function()
        end_current_experiment()
        assert_experiment_record_is_correct(started_record, show_figures=False)
Beispiel #29
0
def test_start_experiment():
    """
    DEPRECATED INTERFACE

    Brackets the experiment code between ``start_experiment`` and
    ``end_current_experiment`` rather than using a with-block, so the
    experiment code needs no extra indentation.
    """

    delete_experiment_with_id('start_stop_test')
    with experiment_testing_context():
        rec = start_experiment('start_stop_test')
        experiment_test_function()
        end_current_experiment()
        assert_experiment_record_is_correct(rec, show_figures=False)
def test_simple_experiment_show():
    """
    Comparing a single record should print the experiment's name and its
    captured stdout.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_simdfsfdsgfs(a=1):
            print('xxxxx')
            print('yyyyy')
            return a+2

        record = my_simdfsfdsgfs.run()

        expected_output = ['my_simdfsfdsgfs', 'xxxxx\nyyyyy\n']
        with assert_things_are_printed(things=expected_output):
            compare_experiment_records(record)
Beispiel #31
0
def test_variants():
    """
    Exercise the ways of creating experiment variants (named, nested, unnamed,
    lists and grids) and check their generated ids and results.
    """
    @experiment_function
    def add_some_numbers(a=1, b=1):
        c = a + b
        print(c)
        return c

    with experiment_testing_context():

        # Create a named variant
        e1 = add_some_numbers.add_variant('b is 3', b=3)
        assert e1.run().get_result() == 4

        # Create a sub-variant: a=2 on top of its parent's b=3
        e11 = e1.add_variant('a is 2', a=2)
        assert e11.run().get_result() == 5

        # Create unnamed variant; its id is derived from its args
        e2 = add_some_numbers.add_variant(b=4)
        assert e2.run().get_result() == 5
        assert e2.get_id() == 'add_some_numbers.b=4'

        # Create array of variants (``range``, not the Python-2-only ``xrange``)
        e_list = [add_some_numbers.add_variant(b=i) for i in range(5, 8)]
        assert [e.get_id() for e in e_list] == [
            'add_some_numbers.b=5', 'add_some_numbers.b=6',
            'add_some_numbers.b=7'
        ]
        # Compare the full list of results. Asserting a list of booleans
        # directly would always pass, because a non-empty list is truthy.
        assert [e.run().get_result() for e in e_list] == [6, 7, 8]

        # Create grid of variants
        e_grid = [
            add_some_numbers.add_variant(a=a, b=b)
            for a, b in itertools.product([2, 3], [4, 5, 6])
        ]
        assert [e.get_id() for e in e_grid] == [
            'add_some_numbers.a=2,b=4', 'add_some_numbers.a=2,b=5',
            'add_some_numbers.a=2,b=6', 'add_some_numbers.a=3,b=4',
            'add_some_numbers.a=3,b=5', 'add_some_numbers.a=3,b=6'
        ]
        assert add_some_numbers.get_variant(a=2, b=4).run().get_result() == 6
        assert add_some_numbers.get_variant(a=3, b=5).run().get_result() == 8

        # Root + e1 + e11 + e2 + 3 list variants + 6 grid variants.
        experiments = add_some_numbers.get_all_variants(include_roots=True,
                                                        include_self=True)
        assert len(experiments) == 13
def test_experiment_errors():
    """
    An exception raised inside an experiment should propagate out of ``run``
    on every attempt, and the latest record should be marked ERROR.
    """

    with experiment_testing_context(new_experiment_lib=True):

        class MyManualException(Exception):
            pass

        @experiment_function
        def my_experiment_fdsgbdn():
            raise MyManualException()

        # Run twice; both attempts must raise.
        for _ in range(2):
            with pytest.raises(MyManualException):
                my_experiment_fdsgbdn.run()

        status = my_experiment_fdsgbdn.get_latest_record().get_status()
        assert status == ExpStatusOptions.ERROR
def test_view_modes():
    """
    Smoke-test the browser in the 'full' and 'results' view modes, quitting
    immediately each time.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_simdfsffsdfsfs(a=1):
            print('xxxxx')
            print('yyyyy')
            return a+2

        for extra_a in (2, 3):
            my_simdfsffsdfsfs.add_variant(a=extra_a)

        my_simdfsffsdfsfs.run()

        for mode in ('full', 'results'):
            my_simdfsffsdfsfs.browse(view_mode=mode, command='q')
def test_generator_experiment():
    """
    For an experiment written as a generator, the record's result is the last
    value yielded. This holds even when the generator raises part-way through
    and the run does not re-raise.
    """

    with experiment_testing_context(new_experiment_lib=True):
        @experiment_root
        def my_generator_exp(n_steps, poison_4=False):
            for i in range(n_steps):
                if poison_4 and i == 4:
                    raise Exception('Unlucky Number!')
                yield i

        clean = my_generator_exp.add_variant(n_steps=5)
        poisoned = my_generator_exp.add_variant(n_steps=5, poison_4=True)

        clean_record = clean.run()
        poisoned_record = poisoned.run(raise_exceptions=False)

        assert clean_record.get_result() == 4     # last of 0..4
        assert poisoned_record.get_result() == 3  # raised before yielding 4
def test_get_variant_records_and_delete():
    """
    Records from the root and its variants should be listed by
    ``get_variant_records(flat=True)`` and disappear once deleted.
    """

    with experiment_testing_context():

        def delete_all_records():
            for rec in my_api_test.get_variant_records(flat=True):
                rec.delete()

        def count_records():
            return len(my_api_test.get_variant_records(flat=True))

        delete_all_records()
        assert count_records()==0

        my_api_test.run()
        my_api_test.get_variant('a2b2').run()
        assert count_records()==2

        delete_all_records()
        assert count_records()==0
Beispiel #36
0
def test_view_modes():
    """
    Open the experiment browser once per view mode ('full' and 'results') and
    quit straight away, just to make sure nothing crashes.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_simdfsffsdfsfs(a=1):
            print('xxxxx')
            print('yyyyy')
            return a + 2

        my_simdfsffsdfsfs.add_variant(a=2)
        my_simdfsffsdfsfs.add_variant(a=3)

        my_simdfsffsdfsfs.run()

        view_modes = ['full', 'results']
        for current_mode in view_modes:
            my_simdfsffsdfsfs.browse(view_mode=current_mode, command='q')
Beispiel #37
0
def test_get_variant_records_and_delete():
    """
    Check that variant records can be listed (flattened) and deleted, starting
    from and returning to an empty state.
    """

    with experiment_testing_context():

        def wipe():
            for existing in my_api_test.get_variant_records(flat=True):
                existing.delete()

        wipe()
        assert len(my_api_test.get_variant_records(flat=True)) == 0

        # One record for the root, one for the 'a2b2' variant.
        my_api_test.run()
        my_api_test.get_variant('a2b2').run()
        assert len(my_api_test.get_variant_records(flat=True)) == 2

        wipe()
        assert len(my_api_test.get_variant_records(flat=True)) == 0
Beispiel #38
0
def test_experiment_errors():
    """
    Errors raised by an experiment must reach the caller of ``run`` (checked
    twice), and the latest record must end up with ERROR status.
    """

    with experiment_testing_context(new_experiment_lib=True):

        class MyManualException(Exception):
            pass

        @experiment_function
        def my_experiment_fdsgbdn():
            raise MyManualException()

        attempts = 2
        for _ in range(attempts):
            with pytest.raises(MyManualException):
                my_experiment_fdsgbdn.run()

        latest_record = my_experiment_fdsgbdn.get_latest_record()
        assert latest_record.get_status() == ExpStatusOptions.ERROR
Beispiel #39
0
def test_current_experiment_access_functions():
    """
    From inside a running experiment, the ``get_current_*`` helpers and
    ``open_in_record_dir`` should all refer to the same existing record.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_dfgsdgfdaf(a=1):

            assert a == 4, 'Only meant to run aaa variant.'

            exp_id = get_current_experiment_id()
            current_record = get_current_experiment_record()
            rec_id = get_current_record_id()

            # Record ids end with "-<experiment id>".
            assert rec_id.endswith('-' + exp_id)
            assert exp_id == 'my_experiment_dfgsdgfdaf.aaa'

            # The record directory exists and is named after the record id.
            rec_dir = get_current_record_dir()
            assert os.path.basename(rec_dir) == rec_id
            assert os.path.isdir(rec_dir)
            assert rec_dir.endswith(rec_id)

            # Round-trip a pickle through the record directory.
            with open_in_record_dir('somefile.pkl', 'wb') as f:
                pickle.dump([1, 2, 3], f, protocol=pickle.HIGHEST_PROTOCOL)
            assert os.path.exists(os.path.join(rec_dir, 'somefile.pkl'))
            with open_in_record_dir('somefile.pkl', 'rb') as f:
                assert pickle.load(f) == [1, 2, 3]

            # The record, its experiment and the lookup functions all agree.
            parent_exp = current_record.get_experiment()
            assert parent_exp.get_id() == exp_id
            assert parent_exp.get_args() == current_record.get_args() == OrderedDict([('a', 4)])
            assert current_record.get_dir() == rec_dir
            assert has_experiment_record(exp_id)
            assert rec_id in experiment_id_to_record_ids(exp_id)
            return a + 2

        aaa_variant = my_experiment_dfgsdgfdaf.add_variant('aaa', a=4)
        aaa_variant.run()
def test_current_experiment_access_functions():
    """
    Check the "current experiment" accessors from inside an experiment run:
    - the ids;
    - the record object;
    - the record directory;
    - file access within that directory;
    - the global record lookup functions.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_experiment_dfgsdgfdaf(a=1):

            assert a==4, 'Only meant to run aaa variant.'
            eid = get_current_experiment_id()
            record = get_current_experiment_record()
            rid = get_current_record_id()

            assert rid.endswith('-'+eid)
            assert eid == 'my_experiment_dfgsdgfdaf.aaa'

            directory = get_current_record_dir()
            assert os.path.basename(directory) == rid
            assert os.path.isdir(directory)
            assert directory.endswith(rid)

            pickled_path = os.path.join(directory, 'somefile.pkl')
            with open_in_record_dir('somefile.pkl', 'wb') as out_file:
                pickle.dump([1, 2, 3], out_file, protocol=pickle.HIGHEST_PROTOCOL)
            assert os.path.exists(pickled_path)
            with open_in_record_dir('somefile.pkl', 'rb') as in_file:
                assert pickle.load(in_file) == [1, 2, 3]

            experiment = record.get_experiment()
            assert experiment.get_id() == eid
            assert experiment.get_args() == record.get_args() == OrderedDict([('a', 4)])
            assert record.get_dir() == directory
            assert has_experiment_record(eid)
            assert rid in experiment_id_to_record_ids(eid)
            return a+2

        my_experiment_dfgsdgfdaf.add_variant('aaa', a=4).run()
def test_figure_saving_and_loading():
    """
    Each ``save_figure_in_record`` call during a run should leave a
    ``fig-<i>.pkl`` file (i counting from 0) in the record directory.
    """

    from artemis.plotting.db_plotting import dbplot
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_exp():
            for _ in range(4):
                dbplot(np.random.randn(20, 20, 3), 'plot')
                save_figure_in_record()

        rec = my_exp.run()  # type: ExperimentRecord

        saved = set(rec.get_figure_locs())
        expected = set(os.path.join(rec.get_dir(), 'fig-{}.pkl'.format(i)) for i in range(4))
        assert saved == expected
def test_generator_experiment():
    """
    Generator experiments should record the last yielded value as their
    result. A variant that raises at step 4, run with
    ``raise_exceptions=False``, therefore ends with result 3.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def my_generator_exp(n_steps, poison_4=False):
            for i in range(n_steps):
                if poison_4 and i == 4:
                    raise Exception('Unlucky Number!')
                yield i

        healthy_variant = my_generator_exp.add_variant(n_steps=5)
        failing_variant = my_generator_exp.add_variant(n_steps=5, poison_4=True)

        healthy = healthy_variant.run()
        failing = failing_variant.run(raise_exceptions=False)

        assert healthy.get_result() == 4
        assert failing.get_result() == 3
Beispiel #43
0
def test_config_bug_catching():
    """
    ``add_config_variant`` should reject common mistakes with an AssertionError:
    - factory args that clash with the experiment's own args;
    - non-callable values;
    - unknown arg names;
    - instances passed instead of factories.
    """

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def demo_smooth_out_signal_testing(smoother, seed=1234):
            signal = np.sin(np.linspace(0, 10, 100)) + 0.1 * np.random.RandomState(seed).randn(100)
            return np.array([smoother(xt) for xt in signal])

        root = demo_smooth_out_signal_testing

        # The right way: a factory whose arg names do not clash.
        root.add_config_variant(
            'exp_smooth',
            smoother=lambda decay=0.1: _ExponentialMovingAverageForTestingPurposes(decay))

        # Arg name 'seed' is already used by the experiment.
        with pytest.raises(AssertionError):
            root.add_config_variant(
                'exp_smooth3',
                smoother=lambda seed=0.1: _ExponentialMovingAverageForTestingPurposes(decay=seed))

        # The value is not a callable.
        with pytest.raises(AssertionError):
            root.add_config_variant('exp_smooth4', smoother=0.1)

        # Misspelled arg name.
        with pytest.raises(AssertionError):
            root.add_config_variant(
                'exp_smooth5',
                smOOOOther=lambda decay: _ExponentialMovingAverageForTestingPurposes(decay=decay))

        # An instance given instead of a factory.
        with pytest.raises(AssertionError):
            root.add_config_variant(
                'exp_smooth6',
                smoother=_ExponentialMovingAverageForTestingPurposes(decay=0.1))
def demo_browse_record_figs():
    """
    Demo (not an automated test). It runs an experiment that saves four line
    plots of a spiral covering a growing range, then opens the figure browser
    on the resulting record.
    """

    from artemis.plotting.db_plotting import dbplot
    from matplotlib import pyplot as plt
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_function
        def my_exp():
            for t in range(4):
                theta = np.linspace(0, 3 * (t + 1), 400)
                spiral = (theta * np.cos(theta), theta * np.sin(theta))
                dbplot(spiral, 'plot', title='t={}'.format(t), plot_type='line')
                save_figure_in_record()
            plt.close(plt.gcf())

        rec = my_exp.run()  # type: ExperimentRecord

        browse_record_figs(rec)
def test_parallel_run_errors():
    """
    Errors from experiments run in parallel are swallowed when
    raise_exceptions=False. With raise_exceptions=True they are re-raised with
    their original message.
    """

    with experiment_testing_context():

        @experiment_function
        def my_error_causing_test(a=1):
            raise Exception('nononono')

        my_error_causing_test.add_variant(a=2)

        variants = my_error_causing_test.get_all_variants()

        assert len(variants)==2  # root + the a=2 variant

        # Must complete without raising, although every experiment fails.
        run_multiple_experiments(variants, parallel=True, raise_exceptions=False)

        with pytest.raises(Exception) as err:
            run_multiple_experiments(variants, parallel=True, raise_exceptions=True)
        print("^^^ Don't worry, the above is not actually an error, we were just asserting that we caught the error.")

        assert str(err.value) == 'nononono'
Beispiel #46
0
def test_invalid_arg_detection_2():
    """
    Check that a record's args stay valid when the experiment is cleared and redefined
    with an identical set-valued default argument.

    NOTE(review): a 100-element set is presumably used so that the arg comparison is
    exercised on a value whose iteration order is not guaranteed — confirm intent.
    """
    with experiment_testing_context(new_experiment_lib=True):

        a = {"a%s" % i for i in range(100)}

        @experiment_function
        def my_experiment_gfdsbhtds(a=a):
            return None

        rec = my_experiment_gfdsbhtds.run()

        assert rec.args_valid() is True
        clear_all_experiments()

        # Redefine under the same name with the very same default set.
        @experiment_function
        def my_experiment_gfdsbhtds(a=a):
            return None

        assert rec.args_valid() is True  # Assert that the args still match
def test_accessing_experiment_dir():
    """
    Check that an experiment can write files into its record directory and that the
    resulting record exposes them (list_files / open_file), along with the captured
    printed output in 'output.txt'.
    """
    with experiment_testing_context():

        @experiment_function
        def access_dir_test():
            print('123')
            print('abc')
            # Named record_dir rather than dir to avoid shadowing the builtin.
            record_dir = get_current_record_dir()
            with open_in_record_dir('my_test_file.txt', 'w') as f:
                f.write('Experiment Directory is: {}'.format(record_dir))

        record = access_dir_test.run()

        filepaths = record.list_files()

        assert 'my_test_file.txt' in filepaths

        # The directory seen inside the experiment must match the record's own dir.
        with record.open_file('my_test_file.txt') as f:
            assert f.read() == 'Experiment Directory is: {}'.format(record.get_dir())

        with record.open_file('output.txt') as f:
            assert f.read() == '123\nabc\n'
def test_invalid_arg_detection_2():
    """
    Check that a record's args stay valid when the experiment is cleared and redefined
    with an identical set-valued default argument.

    NOTE(review): a 100-element set is presumably used so that the arg comparison is
    exercised on a value whose iteration order is not guaranteed — confirm intent.
    """
    with experiment_testing_context(new_experiment_lib=True):

        a = {"a%s" % i for i in range(100)}

        @experiment_function
        def my_experiment_gfdsbhtds(a=a):
            return None

        rec = my_experiment_gfdsbhtds.run()

        assert rec.args_valid() is True
        clear_all_experiments()

        # Redefine under the same name with the very same default set.
        @experiment_function
        def my_experiment_gfdsbhtds(a=a):
            return None

        assert rec.args_valid() is True  # Assert that the args still match
Beispiel #49
0
def test_config_variant():
    """
    Exercise add_config_variant.

    Checks that a variant whose smoother is built by a factory lambda gives a
    reproducible (100,)-shaped result across runs. Checks that such a variant can
    itself be varied further through an ordinary arg (seed). Finally smoke-tests
    the browse UI.
    """
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def demo_smooth_out_signal(smoother, seed=1234):
            rng = np.random.RandomState(seed)
            noisy = np.sin(np.linspace(0, 10, 100)) + 0.1*rng.randn(100)
            return np.array([smoother(sample) for sample in noisy])

        smooth_variant = demo_smooth_out_signal.add_config_variant(
            'exp_smooth',
            smoother=lambda decay=0.1: _ExponentialMovingAverageForTestingPurposes(decay),
        )
        first_answer = smooth_variant.run().get_result()
        assert first_answer.shape == (100, )

        # Re-running must reproduce the result exactly, i.e. a fresh smoother is
        # made each time rather than one whose state carries over between runs.
        rerun_answer = smooth_variant.run().get_result()
        assert np.array_equal(rerun_answer, first_answer)

        # A config variant must still accept plain arg overrides.
        reseeded = smooth_variant.add_variant(seed=1235)
        reseeded_answer = reseeded.run().get_result()
        assert reseeded_answer.shape == (100, )
        assert not np.array_equal(first_answer, reseeded_answer)

        # Smoke-test the browse UI (command='q' presumably exits immediately).
        demo_smooth_out_signal.browse(command='q')
Beispiel #50
0
def test_parameter_search():
    """
    Run a 5-call skopt parameter search over a 2-D quadratic bowl.

    Checks that the result reports the searched parameter names in order and that
    the final entry of func_vals is lower than the first.
    """
    from skopt.space import Real

    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def bowl(x, y):
            # Minimum (z == 0) at x == 2, y == -3.
            return {'z': (x - 2)**2 + (y + 3)**2}

        search_space = dict(
            x=Real(-5, 5, 'uniform'),
            y=Real(-5, 5, 'uniform'),
        )
        search = bowl.add_parameter_search(
            space=search_space,
            scalar_func=lambda result: result['z'],
            search_params=dict(n_calls=5),
        )

        search_result = search.run().get_result()
        assert search_result['names'] == ['x', 'y']
        values = search_result['func_vals']
        assert values[-1] < values[0]
def test_variants():
    """
    Check the ways of creating experiment variants: named, nested, unnamed, a list
    and a grid. For each, verify the run results, the auto-generated IDs, lookup via
    get_variant, and the total variant count.
    """
    # NOTE(review): unlike the other tests, the experiment is defined outside
    # experiment_testing_context — confirm this is intended.
    @experiment_function
    def add_some_numbers(a=1, b=1):
        c = a + b
        print(c)
        return c

    with experiment_testing_context():

        # Create a named variant
        e1 = add_some_numbers.add_variant('b is 3', b=3)
        assert e1.run().get_result() == 4

        # Create a sub-variant
        e11 = e1.add_variant('a is 2', a=2)
        assert e11.run().get_result() == 5

        # Create an unnamed variant; its ID is derived from the overridden args.
        e2 = add_some_numbers.add_variant(b=4)
        assert e2.run().get_result() == 5
        assert e2.get_id() == 'add_some_numbers.b=4'

        # Create a list of variants (range, not the Python-2-only xrange).
        e_list = [add_some_numbers.add_variant(b=i) for i in range(5, 8)]
        assert [e.get_id() for e in e_list] == ['add_some_numbers.b=5', 'add_some_numbers.b=6', 'add_some_numbers.b=7']
        # all(...) is required: asserting a non-empty list is always true.
        assert all(e.run().get_result() == j for e, j in zip(e_list, range(6, 9)))

        # Create a grid of variants
        e_grid = [add_some_numbers.add_variant(a=a, b=b) for a, b in itertools.product([2, 3], [4, 5, 6])]
        assert [e.get_id() for e in e_grid] == ['add_some_numbers.a=2,b=4', 'add_some_numbers.a=2,b=5', 'add_some_numbers.a=2,b=6',
                                                  'add_some_numbers.a=3,b=4', 'add_some_numbers.a=3,b=5', 'add_some_numbers.a=3,b=6']
        assert add_some_numbers.get_variant(a=2, b=4).run().get_result() == 6
        assert add_some_numbers.get_variant(a=3, b=5).run().get_result() == 8

        # root + 'b is 3' + its sub-variant + b=4 + 3 list variants + 6 grid variants = 13
        experiments = add_some_numbers.get_all_variants(include_roots=True, include_self=True)
        assert len(experiments) == 13
def test_invalid_arg_text():
    """
    Check the text that get_record_invalid_arg_string produces for a saved record.

    The experiment is redefined with progressively different default args. Each
    redefinition deliberately reuses the name my_invalid_arg_test and the arg names
    a and b, because the expected strings refer to those args by name.
    """

    with experiment_testing_context(new_experiment_lib=True):

        # Original definition; the record below is produced from these defaults.
        @experiment_function
        def my_invalid_arg_test(a=1, b={'c': 3, 'd': 4}):
            return a+b['c']+b['d']

        record = my_invalid_arg_test.run()
        assert get_record_invalid_arg_string(record, recursive=True) == '<No Change>'
        # Presumably clears the registry so the same name can be registered again.
        clear_all_experiments()

        # Only the top-level arg a changes (1 -> 2).
        @experiment_function
        def my_invalid_arg_test(a=2, b={'c': 3, 'd': 4}):
            return a+b['c']+b['d']

        assert get_record_invalid_arg_string(record, recursive=True) == 'Change: a:1->2'
        clear_all_experiments()

        # Additionally a nested value changes (b['d']: 4 -> 2); with recursive=True the
        # change is reported by its nested key path.
        @experiment_function
        def my_invalid_arg_test(a=2, b={'c': 3, 'd': 2}):
            return a+b['c']+b['d']

        assert get_record_invalid_arg_string(record, recursive=True) == "Change: a:1->2, b['d']:4->2"
Beispiel #53
0
def test_config_bug_catching():
    """
    Check that add_config_variant accepts a proper factory but rejects common mistakes
    with an AssertionError:

    - a factory arg name that clashes with an arg of the root experiment ('seed');
    - a config value that is not callable;
    - a misspelled config name ('smOOOOther' instead of 'smoother');
    - an already-constructed instance passed where a factory is expected.
    """
    with experiment_testing_context(new_experiment_lib=True):

        @experiment_root
        def demo_smooth_out_signal_testing(smoother, seed = 1234):
            signal = np.sin(np.linspace(0, 10, 100)) + 0.1*np.random.RandomState(seed).randn(100)
            y = np.array([smoother(xt) for xt in signal])
            return y

        # The right way: pass a factory (callable) that builds the smoother.
        demo_smooth_out_signal_testing.add_config_variant(
            'exp_smooth',
            smoother=lambda decay=0.1: _ExponentialMovingAverageForTestingPurposes(decay))

        with pytest.raises(AssertionError):  # Arg name already used! ('seed' belongs to the root)
            demo_smooth_out_signal_testing.add_config_variant(
                'exp_smooth3',
                smoother=lambda seed=0.1: _ExponentialMovingAverageForTestingPurposes(decay=seed))

        with pytest.raises(AssertionError):  # Make sure we catch when we do not give a callable
            demo_smooth_out_signal_testing.add_config_variant('exp_smooth4', smoother=0.1)

        with pytest.raises(AssertionError):  # Make sure we catch when we give the wrong name
            demo_smooth_out_signal_testing.add_config_variant(
                'exp_smooth5',
                smOOOOther=lambda decay: _ExponentialMovingAverageForTestingPurposes(decay=decay))

        with pytest.raises(AssertionError):  # Catch when we accidentally give an instance:
            demo_smooth_out_signal_testing.add_config_variant(
                'exp_smooth6',
                smoother=_ExponentialMovingAverageForTestingPurposes(decay=0.1))
Beispiel #54
0
def test_accessing_experiment_dir():
    """
    Check that an experiment can write files into its record directory and that the
    resulting record exposes them (list_files / open_file), along with the captured
    printed output in 'output.txt'.
    """
    with experiment_testing_context():

        @experiment_function
        def access_dir_test():
            print('123')
            print('abc')
            # Named record_dir rather than dir to avoid shadowing the builtin.
            record_dir = get_current_record_dir()
            with open_in_record_dir('my_test_file.txt', 'w') as f:
                f.write('Experiment Directory is: {}'.format(record_dir))

        record = access_dir_test.run()

        filepaths = record.list_files()

        assert 'my_test_file.txt' in filepaths

        # The directory seen inside the experiment must match the record's own dir.
        with record.open_file('my_test_file.txt') as f:
            assert f.read() == 'Experiment Directory is: {}'.format(
                record.get_dir())

        with record.open_file('output.txt') as f:
            assert f.read() == '123\nabc\n'
def test_saving_result():
    # Run root experiment
    with experiment_testing_context():
        rec = add_some_numbers_test_experiment.run()
        assert rec.get_result() == 2