示例#1
0
#
# License: 3-clause BSD, <https://github.com/smarie/python-pytest-cases/blob/master/LICENSE>
# test.py
import pytest
from pytest_cases import fixture_ref, parametrize


@pytest.fixture
def foo():
    """Fixture whose value is the integer 1."""
    value = 1
    return value


@pytest.fixture
def bar():
    """Fixture whose value is the integer 2."""
    value = 2
    return value


@parametrize("arg", [fixture_ref("foo"),
                     fixture_ref("bar")])
def test_thing(arg):
    """Print the value produced by whichever fixture (foo or bar) `arg` refers to."""
    print(arg)


class TestCase:
    """Plain pytest parametrization side by side with fixture_ref parametrization."""

    @pytest.mark.parametrize("arg", [1, 2])
    def test_thing_pytest(self, arg):
        """Print each plain integer parameter."""
        print(arg)

    @parametrize("arg", [fixture_ref("foo"), fixture_ref("bar")])
    def test_thing_cases(self, arg):
        """Print the value of the referenced fixture (foo or bar)."""
        print(arg)
    # setup the database connection
    print("setting up dataset B")
    assert DB is None
    DB = 'DB'

    yield DB

    # teardown the database connection
    print("tearing down dataset B")
    assert DB == 'DB'
    DB = None


@pytest_fixture_plus(scope="module")
@pytest.mark.parametrize('data_index',
                         range(len(datasets_contents['datasetB'])),
                         ids="idx={}".format)
def data_from_datasetB(datasetB, data_index):
    """Return the record at `data_index` of dataset B (one fixture instance per index).

    Depends on the `datasetB` fixture, which must hold the value 'DB'.
    """
    assert datasetB == 'DB'
    contents = datasets_contents['datasetB']
    return contents[data_index]


@pytest_parametrize_plus(
    'data',
    [fixture_ref('data_from_datasetA'), fixture_ref('data_from_datasetB')],
)
def test_databases(data):
    """Run once per record of datasets A and B; the body only prints the record."""
    print(data)
#
# License: 3-clause BSD, <https://github.com/smarie/python-pytest-cases/blob/master/LICENSE>
# test.py
import pytest
from pytest_cases import fixture_ref, pytest_parametrize_plus


@pytest.fixture
def foo():
    """Fixture whose value is the integer 1."""
    value = 1
    return value


@pytest.fixture
def bar():
    """Fixture whose value is the integer 2."""
    value = 2
    return value


@pytest_parametrize_plus("arg", [fixture_ref("foo"),
                                 fixture_ref("bar")])
def test_thing(arg):
    """Print the value produced by whichever fixture (foo or bar) `arg` refers to."""
    print(arg)


class TestCase:
    """Plain pytest parametrization side by side with pytest_parametrize_plus fixture refs."""

    @pytest.mark.parametrize("arg", [1, 2])
    def test_thing_pytest(self, arg):
        """Print each plain integer parameter."""
        print(arg)

    @pytest_parametrize_plus("arg", [fixture_ref("foo"), fixture_ref("bar")])
    def test_thing_cases(self, arg):
        """Print the value of the referenced fixture (foo or bar)."""
        print(arg)
    return val


@fixture
@pytest.mark.parametrize('val', [0, -1])
def myfix2(val):
    """Parametrized fixture: one instance per value, 0 then -1."""
    result = val
    return result


@fixture
@pytest.mark.parametrize('a, b', [('d', 3), ('e', 4)])
def my_tuple(a, b):
    """Return the (a, b) pair of each parametrization: ('d', 3) then ('e', 4)."""
    pair = (a, b)
    return pair


@parametrize('p,q', [
    ('a', 1),
    (fixture_ref(myfix), 2),
    (fixture_ref(myfix), fixture_ref(myfix2)),
    (fixture_ref(myfix), fixture_ref(myfix)),
    fixture_ref(my_tuple),
])
def test_prints(p, q):
    """Print (p, q) for literals, per-argument fixture refs and a tuple-valued fixture."""
    print(p, q)


def test_synthesis(module_results_dct):
    assert list(module_results_dct) == [
        'test_prints[a-1]',
        'test_prints[myfix-2-b]',
        'test_prints[myfix-2-c]',
        'test_prints[myfix-myfix2-b-0]',
        'test_prints[myfix-myfix2-b--1]',
        'test_prints[myfix-myfix2-c-0]',
示例#5
0
import pytest

from pytest_cases import pytest_parametrize_plus, pytest_fixture_plus, fixture_ref


@pytest.fixture
def a():
    """Fixture whose value is the string 'a'."""
    letter = 'a'
    return letter


@pytest_fixture_plus
@pytest_parametrize_plus('second_letter', [fixture_ref('a'), 'o'])
def b(second_letter):
    """Return 'b' followed by `second_letter` (the `a` fixture's value, or 'o')."""
    prefix = 'b'
    return prefix + second_letter


@pytest_parametrize_plus('arg', ['z', fixture_ref(a), fixture_ref(b), 'o'])
@pytest.mark.parametrize('bar', ['bar'])
def test_foo(arg, bar):
    """Combine fixture_ref parametrization with a plain pytest parametrize mark."""
    expected_args = ['z', 'a', 'ba', 'bo', 'o']
    assert bar == 'bar'
    assert arg in expected_args
def test_invalid_argvalues():
    """Decorating with a bare fixture_ref as argvalues must raise InvalidParamsList."""
    with pytest.raises(InvalidParamsList):
        # argvalues is a single fixture_ref rather than a list of values; the
        # error is expected while the decorator is being applied.
        @parametrize_plus('main_msg', fixture_ref(test))
        def test_prints(main_msg):
            print(main_msg)
示例#7
0
import pytest
from pytest_cases import fixture_ref, parametrize_plus

from carnival import cmd, global_context


@pytest.mark.slow
@parametrize_plus('host', [
    fixture_ref('ubuntu_ssh_host'),
])
def test_install_ce_ubuntu(suspend_capture, host):
    """Install Docker CE on an Ubuntu host and check that /usr/bin/docker appears.

    The docker-ce package is removed afterwards even when the post-install
    check fails. Otherwise a failed run would leave docker installed, and the
    "not installed yet" precondition would fail on every following run.
    """
    with suspend_capture:
        with global_context.SetContext(host):
            # precondition: docker must not be present before the install
            assert cmd.fs.is_file_exists("/usr/bin/docker") is False
            cmd.docker.install_ce_ubuntu()
            try:
                assert cmd.fs.is_file_exists("/usr/bin/docker") is True
            finally:
                # always undo the install so the host is reusable
                cmd.apt.remove("docker-ce")


@pytest.mark.slow
@parametrize_plus('host', [
    fixture_ref('ubuntu_ssh_host'),
    fixture_ref('centos_ssh_host'),
])
def test_install_compose(suspend_capture, host):
    with suspend_capture:
        with global_context.SetContext(host):
            assert cmd.fs.is_file_exists(
                "/usr/local/bin/docker-compose") is False
            cmd.docker.install_compose()
            assert cmd.fs.is_file_exists(
# ------ Datasets fixture generation
def create_dataset_fixture(dataset_name):
    """Build a module-scoped fixture named `dataset_name` that yields datasets[dataset_name].

    Setup and teardown messages are printed around the yield.
    """
    def dataset():
        print("setting up dataset %s" % dataset_name)
        yield datasets[dataset_name]
        print("tearing down dataset %s" % dataset_name)

    # apply the decorator explicitly: the fixture is registered under dataset_name
    make_fixture = fixture(scope="module", name=dataset_name)
    return make_fixture(dataset)

def create_data_from_dataset_fixture(dataset_name, indices=None):
    """Build a module-scoped fixture ``data_from_<dataset_name>`` parametrized over record indices.

    :param dataset_name: name of the dataset fixture the new fixture depends on.
    :param indices: indices to parametrize over. Defaults to
        ``datasets_indices[dataset_name]``.
    :return: the decorated fixture function.
    """
    if indices is None:
        # Look the indices up explicitly. The previous version silently read the
        # module-level loop variable `dataset_indices`, which is only correct
        # while the registration loop is running.
        indices = datasets_indices[dataset_name]

    @fixture(name="data_from_%s" % dataset_name, scope="module")
    @pytest.mark.parametrize('data_index', indices, ids="idx={}".format)
    @with_signature("(%s, data_index)" % dataset_name)
    def data_from_dataset(data_index, **kwargs):
        # with_signature declares the dataset fixture as an argument named
        # `dataset_name`; it arrives here in **kwargs under that key.
        dataset = kwargs[dataset_name]
        return dataset[data_index]

    return data_from_dataset

# Register, for every dataset, a dataset fixture and a "data_from_<name>" fixture
# by injecting them into the module namespace (pytest discovers fixtures there).
# NOTE: the loop variables `dataset_name` / `dataset_indices` remain module-level
# names after the loop; keep them unchanged in case other code reads them.
for dataset_name, dataset_indices in datasets_indices.items():
    globals()[dataset_name] = create_dataset_fixture(dataset_name)
    globals()["data_from_%s" % dataset_name] = create_data_from_dataset_fixture(dataset_name)

# ------ Test
@parametrize('data', [fixture_ref('data_from_%s' % name)
                      for name in datasets_indices])
def test_databases(data):
    """Run once per record of every registered dataset; the body only prints it."""
    print(data)
@pytest.fixture()
def img_arr():
    """Rank-3 Float array of shape (3, 3, 3) filled with ones."""
    ones = np.ones((3, 3, 3))
    return Array[Float, 3](ones)


@pytest.fixture()
def col_arr():
    """Rank-4 Float array of shape (2, 3, 3, 3) filled with ones."""
    ones = np.ones((2, 3, 3, 3))
    return Array[Float, 4](ones)


@pytest.fixture()
def img_arr_masked():
    """Rank-3 masked array of ones, shape (3, 3, 3), masking every third element (flat order)."""
    shape = (3, 3, 3)
    data = np.ones(shape)
    flat_positions = np.arange(data.size)
    mask = (flat_positions % 3 == 0).reshape(shape)
    return MaskedArray[Float, 3](data, mask)


@pytest.fixture()
def col_arr_masked():
    """Rank-4 masked array of ones, shape (2, 3, 3, 3), masking every third element (flat order)."""
    shape = (2, 3, 3, 3)
    data = np.ones(shape)
    flat_positions = np.arange(data.size)
    mask = (flat_positions % 3 == 0).reshape(shape)
    return MaskedArray[Float, 4](data, mask)


# Groups of fixture references, used to parametrize tests over array variants.
unmasked_arrays = [fixture_ref("img_arr"), fixture_ref("col_arr")]
masked_arrays = [fixture_ref("img_arr_masked"), fixture_ref("col_arr_masked")]
img_arrays = [fixture_ref(fix_name) for fix_name in ("img_arr", "img_arr_masked")]
col_arrays = [fixture_ref(fix_name) for fix_name in ("col_arr", "col_arr_masked")]
all_arrays = [*unmasked_arrays, *masked_arrays]
# ----------- fix for issue 213


@parametrize(name=("bar", ))
def case_foo2(name):
    """Case parametrized with the single value name='bar'; returns it unchanged."""
    result = name
    return result


@parametrize_with_cases("a", cases=case_foo2)
def test_foo2(a, current_cases):
    """Issue 213: current_cases reports the case id, case function and its parameters."""
    expected = {'a': ('foo2', case_foo2, {'name': 'bar'})}
    assert current_cases == expected


# ----------- fix for issue 213 bis


@fixture
def o():
    """Fixture whose value is the string 'name'."""
    value = "name"
    return value


@fixture
@parametrize("a", (fixture_ref(o), 'r'))
def a_fix(a):
    """Fixture parametrized with the `o` fixture's value or the literal 'r'."""
    result = a
    return result


def test_foo3(a_fix, current_cases):
    """Issue 213 bis: a test without parametrize_with_cases sees no current cases."""
    no_cases = {}
    assert current_cases == no_cases
示例#11
0
@pytest_fixture_plus
@pytest.mark.parametrize('val', [0, -1])
def myfix2(val):
    """Parametrized fixture: one instance per value, 0 then -1."""
    result = val
    return result


@pytest_fixture_plus
@pytest.mark.parametrize('val', [('d', 3), ('e', 4)])
def my_tuple(val):
    """Tuple-valued parametrized fixture: ('d', 3) then ('e', 4)."""
    pair = val
    return pair


@pytest_parametrize_plus('p,q', [
    ('a', 1),
    (fixture_ref(myfix), 2),
    (fixture_ref(myfix), fixture_ref(myfix2)),
    (fixture_ref(myfix), fixture_ref(myfix)),
    fixture_ref(my_tuple),
])
def test_prints(p, q):
    """Print (p, q) for literals, per-argument fixture refs and a tuple-valued fixture."""
    print(p, q)


def test_synthesis(module_results_dct):
    assert list(module_results_dct) == ['test_prints[p_q_is_a-1]',
                                        'test_prints[p_q_is_P1-b]',
                                        'test_prints[p_q_is_P1-c]',
                                        'test_prints[p_q_is_P2-b-0]',
                                        'test_prints[p_q_is_P2-b--1]',
                                        'test_prints[p_q_is_P2-c-0]',
                                        'test_prints[p_q_is_P2-c--1]',
示例#12
0

@fixture_plus(scope="session")
@parametrize_plus("namespace", ["", "airflow-worker-pods"])
def pod_namespace(namespace):
    """Session fixture over the pod namespaces: '' and 'airflow-worker-pods'."""
    selected = namespace
    return selected


@fixture_plus(scope="session")
@parametrize_plus("executor_name", ["CeleryExecutor", "KubernetesExecutor"])
def executor(executor_name):
    """Session fixture over the executor names: CeleryExecutor and KubernetesExecutor."""
    selected = executor_name
    return selected


@fixture_plus(scope="session")
@parametrize_plus("executor, pod_namespace",
                  [(fixture_ref(executor), fixture_ref(pod_namespace))])
def airflow_options(executor, pod_namespace):
    """Session-scoped AirflowOptions for each executor / pod-namespace variant.

    pytest-cases' parametrize decorator must sit *below* its own fixture
    decorator (``fixture_plus``). The previous ordering wrapped an
    already-created ``@pytest.fixture`` with ``@parametrize_plus``, which is not
    a supported combination.
    """
    return AirflowOptions(
        dag_sync_image="alpine/git",
        dag_sync_command=[
            "/bin/sh", "-c",
            parse_shell_script(str(dag_copy_loc))
        ],
        dag_sync_schedule="* * * * *",
        default_timezone="est",
        core_executor=executor,
        open_node_ports=True,
        local_mode=True,
        pods_namespace=pod_namespace,
    )
示例#13
0
@fixture(scope="module")
def datasetB():
    """Module-scoped fixture that simulates a connection to dataset B through the global DB.

    Setup requires DB to be None and sets it to 'DB'. Teardown requires DB to
    still be 'DB' and then resets it to None.
    """
    global DB
    connection = 'DB'

    print("setting up dataset B")
    assert DB is None
    DB = connection
    yield connection

    print("tearing down dataset B")
    assert DB == connection
    DB = None


@fixture(scope="module")
@pytest.mark.parametrize('data_index',
                         range(len(datasets_contents['datasetB'])),
                         ids="idx={}".format)
def data_from_datasetB(datasetB, data_index):
    """Return the record at `data_index` of dataset B (one fixture instance per index)."""
    assert datasetB == 'DB'
    contents = datasets_contents['datasetB']
    return contents[data_index]


@parametrize('data', [
    fixture_ref('data_from_datasetA'),
    fixture_ref('data_from_datasetB'),
])
def test_databases(data):
    """Run once per record of datasets A and B; the body only prints the record."""
    print(data)
import pytest
from pytest_cases.common_pytest_marks import has_pytest_param

from pytest_cases import fixture, parametrize, fixture_ref


if has_pytest_param:
    @fixture
    def b():
        """Print 'b' when instantiated, and return 'b'."""
        value = "b"
        print(value)
        return value

    @parametrize("fixture", [fixture_ref(b), pytest.param(fixture_ref(b))])
    def test(fixture):
        """A fixture_ref resolves to 'b' whether or not it is wrapped in pytest.param."""
        assert fixture == "b"
        print("Test ran fixure %s" % fixture)

    @parametrize("fixture,a", [(fixture_ref(b), 1), pytest.param(fixture_ref(b), 1)])
    def test2(fixture, a):
        """Same as `test`, but with fixture_ref inside a two-argument tuple."""
        assert fixture == "b"
        assert a == 1
        print("Test ran fixure %s" % fixture)
#
# License: 3-clause BSD, <https://github.com/smarie/python-pytest-cases/blob/master/LICENSE>
import pytest
from pytest_cases import parametrize_plus, fixture_plus, fixture_ref, lazy_value


@pytest.fixture
def world_str():
    """Fixture whose value is the string 'world'."""
    word = 'world'
    return word


def whatfun():
    """Return the string 'what' (used as a lazy_value source)."""
    word = 'what'
    return word


@fixture_plus
@parametrize_plus('who', [fixture_ref(world_str), 'you'])
def greetings(who):
    """Return 'hello ' followed by `who` (the world_str fixture value, or 'you')."""
    greeting = 'hello '
    return greeting + who


@parametrize_plus('main_msg', ['nothing',
                               fixture_ref(world_str),
                               lazy_value(whatfun),
                               fixture_ref(greetings)])
@pytest.mark.parametrize('ending', ['?', '!'])
def test_prints(main_msg, ending):
    """Print each message (literal, fixture, lazy value or nested fixture) with each ending."""
    message = main_msg + ending
    print(message)

# pytest.param is not available in all versions
if has_pytest_param:
    @pytest.fixture
    def a():
        """Fixture whose value is the string 'a'."""
        letter = 'a'
        return letter


    @pytest.fixture
    def b():
        """Fixture whose value is the string 'b'."""
        letter = 'b'
        return letter


    @parametrize_plus('arg', [
        pytest.param("a", marks=pytest.mark.skipif("5>4")),
        fixture_ref(b),
    ])
    def test_mark(arg):
        """Marks on a pytest.param entry are honored (the 'a' variant is always skipped)."""
        allowed = ['a', 'b']
        assert arg in allowed


    @parametrize_plus('arg', [
        pytest.param("a", id="testID"),
        fixture_ref(b),
    ])
    def test_id(arg):
        """An explicit id on a pytest.param entry is used for that test variant."""
        allowed = ['a', 'b']
        assert arg in allowed


    def test_synthesis(module_results_dct):
        # make sure the id and skip mark were taken into account
        assert list(module_results_dct) == [
            'test_mark[arg_is_b]',
            'test_id[testID]',
# True when the installed pytest exposes `pytest.param` (not every version does).
has_pytest_param = hasattr(pytest, 'param')

# pytest.param is not available in all versions
if has_pytest_param:

    @pytest.fixture
    def a():
        """Fixture whose value is the string 'a'."""
        letter = 'a'
        return letter

    @pytest.fixture(params=['r', 't'], ids="b={}".format)
    def b(request):
        """Fixture parametrized over 'r' and 't'; its value is 'br' or 'bt'."""
        suffix = request.param
        return "b%s" % suffix

    @parametrize_plus('foo',
                      [1,
                       fixture_ref(b),
                       pytest.param('t'),
                       pytest.param('r', id='W'),
                       3,
                       pytest.param(fixture_ref(a)),
                       fixture_ref(a)],
                      ids=list(map(str, range(7))))
    def test_id(foo):
        """Explicit ids combined with pytest.param ids and fixture refs.

        Only the generated test ids matter here (they are checked by the
        synthesis test); the body does nothing.
        """

    def test_synthesis(module_results_dct):
        # make sure the id and skip mark were taken into account
        assert list(module_results_dct) == [
            'test_id[0]', 'test_id[1-b=r]', 'test_id[1-b=t]',
            'test_id[foo_is_P2toP4-2]', 'test_id[foo_is_P2toP4-W]',
            'test_id[foo_is_P2toP4-4]', 'test_id[5]', 'test_id[6]'
示例#18
0
class TestFullData:
    """Tests on a realistic dataset."""

    def test_correct_num_force_plates(self, full_data):
        assert len(full_data.forcepl) == 2

    def test_there_is_emg(self, full_data):
        assert full_data.emg is not None

    def test_correct_num_traj(self, full_data):
        assert len(full_data.traj) == 40

    @parametrize(
        "devices, exp_names",
        [
            (fixture_ref("full_data_forcep"),
             fixture_ref("full_data_forcep_names")),
            (fixture_ref("full_data_emg_list"),
             fixture_ref("full_data_emg_names")),
            (fixture_ref("full_data_traj"),
             fixture_ref("full_data_traj_names")),
        ],
    )
    def test_load_correct_names(self, devices, exp_names):
        """Each loaded device carries the expected name, in order."""
        devices, exp_names = list(devices), list(exp_names)
        # zip() stops at the shorter input, so a missing or extra device would
        # otherwise pass silently; compare the counts first.
        assert len(devices) == len(exp_names)
        for dev, name in zip(devices, exp_names):
            assert dev.name == name

    @parametrize(
        "devices, exp_cols",
        [
            (fixture_ref("full_data_forcep"), fixture_ref("forcep_cols")),
            (fixture_ref("full_data_emg_list"), fixture_ref("emg_cols")),
            (fixture_ref("full_data_traj"), fixture_ref("traj_cols")),
        ],
    )
    def test_correct_cols(self, devices, exp_cols):
        """Every device's dataframe has exactly the expected columns, in order."""
        expected = tuple(exp_cols)  # loop-invariant
        for dev in devices:
            assert tuple(dev.df.columns) == expected

    @parametrize(
        "devices, exp_units",
        [
            (fixture_ref("full_data_forcep"), fixture_ref("forcep_units")),
            (fixture_ref("full_data_emg_list"), fixture_ref("emg_units")),
            (fixture_ref("full_data_traj"), fixture_ref("traj_units")),
        ],
    )
    def test_correct_units(self, devices, exp_units):
        """Every device reports exactly the expected units, in order."""
        expected = tuple(exp_units)  # loop-invariant
        for dev in devices:
            assert tuple(dev.units) == expected

    def test_traj_sampling_freq(self, full_data_traj):
        for dev in full_data_traj:
            assert dev.sampling_frequency == 100

    def test_forces_emg_sampling_freq(self, full_data_forces_emg):
        for dev in full_data_forces_emg:
            assert dev.sampling_frequency == 2000

    @parametrize(
        "devices, exp_shape",
        [
            (fixture_ref("full_data_forcep"),
             fixture_ref("full_data_forcep_shape")),
            (fixture_ref("full_data_emg_list"),
             fixture_ref("full_data_emg_shape")),
            (fixture_ref("full_data_traj"),
             fixture_ref("full_data_traj_shape")),
        ],
    )
    def test_traj_data_shape(self, devices, exp_shape):
        """Every device's dataframe has the expected shape (checked for all device kinds)."""
        for dev in devices:
            assert dev.df.shape == exp_shape

    def test_col_average_traj(self, full_data_angelica_hv,
                              angelica_hv_average):
        """Column means of the X, Y and Z trajectory columns match the expected averages."""
        datafr = full_data_angelica_hv.df
        exp_x, exp_y, exp_z = angelica_hv_average
        for col, expected in (("X", exp_x), ("Y", exp_y), ("Z", exp_z)):
            assert np.isclose(datafr[col].mean(), expected)

    def test_col_average_forcepl_last_5000(self, full_data_forcepl_2,
                                           forcepl2_average):
        """Mean of the last 5000 samples of each force-plate column matches the expected value."""
        datafr = full_data_forcepl_2.df
        # NOTE(review): zip stops at the shorter of (columns, expected averages);
        # confirm that checking only the first len(forcepl2_average) columns is intended.
        for col, exp_average in zip(datafr, forcepl2_average):
            last_5000 = datafr[col].iloc[-5000:]
            assert np.isclose(last_5000.mean(), exp_average)
from pytest_cases import fixture_ref, pytest_parametrize_plus, pytest_fixture_plus


@pytest_fixture_plus
@pytest_parametrize_plus("variant", ['A', 'B'])
def book1(variant):
    """Parametrized fixture whose value is the variant letter, 'A' or 'B'."""
    chosen = variant
    return chosen


@pytest.fixture
def book2():
    """Fixture whose value is None."""
    return None


@pytest_parametrize_plus("name", [fixture_ref(book1), 'hi', 'ih', fixture_ref(book2)])
def test_get_or_create_book(name):
    """Interleave parametrized / plain fixture refs with literal values; just prints."""
    print(name)


def test_synthesis(module_results_dct):
    assert list(module_results_dct) == [
        'test_get_or_create_book[name_is_book1-A]',
        'test_get_or_create_book[name_is_book1-B]',
        'test_get_or_create_book[name_is_P1toP2-hi]',
        'test_get_or_create_book[name_is_P1toP2-ih]',
        'test_get_or_create_book[name_is_book2]'
示例#20
0
    @pytest.fixture
    @saved_fixture
    def a():
        """Fixture 'a', also wrapped by saved_fixture (presumably to record its value)."""
        letter = 'a'
        return letter

    @pytest_fixture_plus
    @saved_fixture
    @pytest.mark.parametrize('i', [5, 6])
    def b(i):
        """Parametrized fixture whose value is 'b5' or 'b6'; also wrapped by saved_fixture."""
        value = 'b%s' % i
        return value

    @parametrize_plus('arg',
                      [pytest.param('c'),
                       pytest.param(fixture_ref(a)),
                       fixture_ref(b)],
                      hook=saved_fixture)
    def test_fixture_ref1(arg):
        """`arg` is a literal, a fixture ref wrapped in pytest.param, or a parametrized fixture ref."""
        allowed = ['a', 'b5', 'b6', 'c']
        assert arg in allowed

    def test_synthesis1(request, fixture_store):
        results_dct1 = get_session_synthesis_dct(request,
                                                 filter=test_fixture_ref1,
                                                 test_id_format='function',
                                                 fixture_store=fixture_store,
                                                 flatten=True)
        assert [(k, v['test_fixture_ref1_arg'])
                for k, v in results_dct1.items()] == [
                    ('test_fixture_ref1[arg_is_c]', 'c'),
                    ('test_fixture_ref1[arg_is_a]', 'a'),

@fixture_plus
@parametrize_plus("i", [5, 7])
def bfix(i):
    """Parametrized fixture whose value is the negated parameter: -5 or -7."""
    negated = -i
    return negated


def val():
    """Return the constant 1 (used as a lazy_value source)."""
    one = 1
    return one


# True when the installed pytest exposes `pytest.param` (not every version does).
has_pytest_param = hasattr(pytest, 'param')
if not has_pytest_param:
    @parametrize_plus("a", [
        lazy_value(val),
        fixture_ref(bfix),
        lazy_value(val, id='A'),
    ])
    def test_foo_single(a):
        """Single argument `a` comes from a lazy value, each bfix value, or the lazy value with id 'A'."""
        allowed = (1, -5, -7)
        assert a in allowed


    def test_synthesis2(module_results_dct):
        """Check the generated test ids for test_foo_single, in collection order."""
        expected_ids = [
            'test_foo_single[a_is_val]',
            'test_foo_single[a_is_bfix-5]',
            'test_foo_single[a_is_bfix-7]',
            'test_foo_single[a_is_A]',
        ]
        assert list(module_results_dct) == expected_ids


else:
示例#22
0
        print("tearing down dataset %s" % dataset_name)

    return dataset


def create_data_from_dataset_fixture(dataset_name, indices=None):
    """Build a module-scoped fixture ``data_from_<dataset_name>`` parametrized over record indices.

    :param dataset_name: name of the dataset fixture the new fixture depends on.
    :param indices: indices to parametrize over. Defaults to
        ``datasets_indices[dataset_name]``.
    :return: the decorated fixture function.
    """
    if indices is None:
        # Look the indices up explicitly. The previous version silently read the
        # module-level loop variable `dataset_indices`, which is only correct
        # while the registration loop is running.
        indices = datasets_indices[dataset_name]

    @pytest_fixture_plus(name="data_from_%s" % dataset_name, scope="module")
    @pytest.mark.parametrize('data_index',
                             indices,
                             ids="idx={}".format)
    @with_signature("(%s, data_index)" % dataset_name)
    def data_from_dataset(data_index, **kwargs):
        # with_signature declares the dataset fixture as an argument named
        # `dataset_name`; it arrives here in **kwargs under that key.
        dataset = kwargs[dataset_name]
        return dataset[data_index]

    return data_from_dataset


# Register, for every dataset, a dataset fixture and a "data_from_<name>" fixture
# by injecting them into the module namespace (pytest discovers fixtures there).
# NOTE: the loop variables `dataset_name` / `dataset_indices` remain module-level
# names after the loop; keep them unchanged in case other code reads them.
for dataset_name, dataset_indices in datasets_indices.items():
    globals()[dataset_name] = create_dataset_fixture(dataset_name)
    globals()["data_from_%s" %
              dataset_name] = create_data_from_dataset_fixture(dataset_name)


# ------ Test
@pytest_parametrize_plus(
    'data', [fixture_ref('data_from_%s' % n) for n in datasets_indices.keys()])
def test_databases(data):
    """Run once per record of every registered dataset; the body only prints the record.

    A stray ``return 15, 2`` (pasted in from another function) was removed: a test
    function must return None, and pytest warns about tests that return a value.
    """
    print(data)


def valtuple_only_right_when_lazy():
    """Return (0, -1) once the module-level `flag` is truthy; raise ValueError before that.

    Presumably used to check that lazy values are resolved late (after `flag`
    has been set), not at collection time.
    """
    # `flag` is only read here, so no `global` declaration is needed
    if not flag:
        raise ValueError("not yet ready ! you should call me later ")
    return 0, -1


# True when the installed pytest exposes `pytest.param` (not every version does).
has_pytest_param = hasattr(pytest, 'param')
if not has_pytest_param:
    @parametrize_plus("a,b",
                      [lazy_value(valtuple),
                       lazy_value(valtuple, id='A'),
                       fixture_ref(tfix),
                       (fixture_ref(vfix), lazy_value(val)),
                       (lazy_value(val, id='B'), fixture_ref(vfix)),
                       (fixture_ref(vfix), fixture_ref(vfix))],
                      debug=True)
    def test_foo_multi(a, b):
        """Parametrize (a, b) jointly: lazy tuples, a tuple-valued fixture, or per-argument mixes.

        Also sets the module-level `flag`, which valtuple_only_right_when_lazy reads.
        """
        global flag
        flag = True
        allowed = ((1, 2), (1, 1), (1, 3), (-5, 1), (11, 13),
                   (-1, 1), (1, -5), (1, -1), (-5, -5), (-1, -1))
        assert (a, b) in allowed


    def test_synthesis2(module_results_dct):
        assert list(module_results_dct) == ['test_foo_multi[a_b_is_P0toP1-valtuple]',
                                            'test_foo_multi[a_b_is_P0toP1-A]',
                                            'test_foo_multi[a_b_is_tfix-val]',
示例#24
0
from pytest_cases import parametrize_plus, pytest_fixture_plus, fixture_ref


@pytest.fixture
def a():
    """Tuple-valued fixture: ('A', 'AA')."""
    pair = ('A', 'AA')
    return pair


@pytest_fixture_plus
@pytest.mark.parametrize('arg', [1, 2])
def b(arg):
    """Parametrized fixture whose value is 'B1' or 'B2'."""
    label = "B%s" % arg
    return label


@parametrize_plus("arg1,arg2", [
    ('1', None),
    (None, '2'),
    fixture_ref('a'),
    ('4', '4'),
    ('3', fixture_ref('b')),
])
def test_foo(arg1, arg2):
    """Mix literal tuples, a tuple-valued fixture and a per-argument fixture ref; just prints."""
    print(arg1, arg2)


def test_synthesis(module_results_dct):
    """See https://github.com/smarie/python-pytest-cases/issues/86

    Check the generated ids of test_foo, in collection order.
    """
    expected_ids = [
        'test_foo[arg1_arg2_is_P0toP1-1-None]',
        'test_foo[arg1_arg2_is_P0toP1-None-2]',
        'test_foo[arg1_arg2_is_a]',
        'test_foo[arg1_arg2_is_4-4]',
        'test_foo[arg1_arg2_is_P4-1]',
        'test_foo[arg1_arg2_is_P4-2]',
    ]
    assert list(module_results_dct) == expected_ids
示例#25
0
@pytest.fixture
def mock_open_workbook(mocker):
    """Patch xlrd.open_workbook for the duration of the test and return the mock."""
    return mocker.patch("xlrd.open_workbook")


@pytest.fixture
def read_excel(mock_open_workbook):
    """Return (reader, mocks) for reading fit data from Excel.

    `reader(**kwargs)` calls FitData.read_from_excel on the module's filepath and
    sheet_name. `mocks` holds the patched workbook opener under "reader" and the
    row-populating helper under "row_setter".
    """
    def actual_read(**kwargs):
        return FitData.read_from_excel(filepath, sheet_name, **kwargs)

    mocks = dict(reader=mock_open_workbook, row_setter=set_excel_rows)
    return actual_read, mocks


@parametrize_plus("read, mocks", [fixture_ref(read_csv), fixture_ref(read_excel)])
def test_read_with_headers_successful(read, mocks):
    """Reading rows that include a header yields the expected data and columns."""
    set_rows, reader_mock = mocks["row_setter"], mocks["reader"]
    set_rows(reader_mock, ROWS)

    actual_fit_data = read()

    check_data_by_keys(actual_fit_data)
    check_columns(actual_fit_data)


@parametrize_plus("read, mocks", [fixture_ref(read_csv), fixture_ref(read_excel)])
def test_read_without_headers_successful(read, mocks):
    """Reading header-less rows must not raise."""
    set_rows, reader_mock = mocks["row_setter"], mocks["reader"]
    set_rows(reader_mock, CONTENT)

    # NOTE(review): the result is never checked here (the snippet may be
    # truncated); confirm whether data/column assertions are missing.
    actual_fit_data = read()
# pytest.param is not available in all versions
if has_pytest_param:

    @pytest.fixture
    def a():
        """Fixture whose value is the string 'a'."""
        letter = 'a'
        return letter

    @pytest.fixture
    def b():
        """Fixture whose value is the string 'b'."""
        letter = 'b'
        return letter

    @parametrize('arg1,arg2',
                 [pytest.param("a", 1, id="testID"),
                  ("b", 1),
                  (fixture_ref(b), 1),
                  pytest.param("c", 1, id="testID3"),
                  pytest.param(fixture_ref(b), 1, id="testID4"),
                  ("c", 1)],
                 debug=True)
    def test_id_tuple(arg1, arg2):
        """Tuple argvalues mixing explicit pytest.param ids, plain tuples and fixture refs."""
        allowed_first = ['a', 'b', 'c']
        assert arg1 in allowed_first and arg2 == 1

    def test_synthesis(module_results_dct):
        # make sure the id and skip mark were taken into account
        assert list(module_results_dct) == [
            'test_id_tuple[testID]',
            'test_id_tuple[b-1]',
            'test_id_tuple[testID3]',
            'test_id_tuple[testID4]',
示例#27
0
    assert env._total_episode_reward == [0 for _ in range(env.n_agents)]
    assert env._agent_dones == [False for _ in range(env.n_agents)]


def test_reset_after_episode_end(env):
    """Play one episode with random actions, check the step and reward bookkeeping, then re-check reset."""
    env.reset()
    done = [False] * env.n_agents
    step_i = 0
    ep_reward = [0] * env.n_agents
    while not all(done):
        step_i += 1
        _, reward_n, done, _ = env.step(env.action_space.sample())
        for agent_idx in range(env.n_agents):
            ep_reward[agent_idx] += reward_n[agent_idx]

    assert step_i == env._step_count
    assert env._total_episode_reward == ep_reward
    test_reset(env)


@pytest_parametrize_plus('env', [fixture_ref(env)])
def test_observation_space(env):
    """Every observation seen during a random episode lies in the observation space.

    A stray ``return fixture_fun`` (pasted in from another function) was removed:
    `fixture_fun` is not defined in this test, and a test must return None.
    """
    obs = env.reset()
    assert env.observation_space.contains(obs)
    done = [False] * env.n_agents
    while not all(done):
        obs, _, done, _ = env.step(env.action_space.sample())
        assert env.observation_space.contains(obs)
    assert env.observation_space.contains(obs)
    assert env.observation_space.contains(env.observation_space.sample())


@fixture_plus(hook=my_hook)
def foo():
    """Tuple-valued fixture (2, 1), created through my_hook."""
    pair = (2, 1)
    return pair


# Derived fixtures, each created through my_hook. Keep these statements in this
# order: test_synthesis asserts that f_list (presumably filled by my_hook) lists
# the fixtures in creation order ('foo', 'o', 'p', 'p1', ...).
o, p = unpack_fixture('o,p', foo, hook=my_hook)

p1 = param_fixture("p1", [1, 2], hook=my_hook)

p2, p3 = param_fixtures("p2,p3", [(3, 4)], hook=my_hook)

u = fixture_union("u", (o, p), hook=my_hook)


@parametrize_plus("arg", [fixture_ref(u),
                          fixture_ref(p1)])
def test_a(arg, p2, p3):
    """Parametrize over a fixture union and a param fixture; only the generated ids matter."""


def test_synthesis(module_results_dct):
    """Check test_a's generated ids and the order in which the hook saw each fixture."""
    expected_ids = [
        'test_a[arg_is_u-u_is_o-3-4]',
        'test_a[arg_is_u-u_is_p-3-4]',
        'test_a[arg_is_p1-1-3-4]',
        'test_a[arg_is_p1-2-3-4]',
    ]
    expected_hooked = [
        'foo', 'o', 'p', 'p1', 'p2_p3__param_fixtures_root', 'p2', 'p3', 'u'
    ]
    assert list(module_results_dct) == expected_ids
    assert f_list == expected_hooked
示例#29
0
    assert env._total_episode_reward == reward_n, 'Total Episode reward doesn\'t match with one step reward'


def test_reset_after_episode_end(env):
    """Play one episode with random actions, check the step and reward bookkeeping, then re-check reset."""
    env.reset()
    done = [False] * env.n_agents
    step_i = 0
    ep_reward = [0] * env.n_agents
    while not all(done):
        step_i += 1
        _, reward_n, done, _ = env.step(env.action_space.sample())
        for agent_idx in range(env.n_agents):
            ep_reward[agent_idx] += reward_n[agent_idx]

    assert step_i == env._step_count
    assert env._total_episode_reward == ep_reward
    test_reset(env)


@parametrize_plus('env', [fixture_ref(env), fixture_ref(env_full)])
def test_observation_space(env):
    """For both env variants, every observation seen during a random episode lies in the observation space."""
    obs = env.reset()
    assert env.observation_space.contains(obs)
    done = [False] * env.n_agents
    while not all(done):
        obs, _, done, _ = env.step(env.action_space.sample())
        assert env.observation_space.contains(obs)
    assert env.observation_space.contains(obs)
    assert env.observation_space.contains(env.observation_space.sample())
示例#30
0
# Authors: Sylvain MARIE <*****@*****.**>
#          + All contributors to <https://github.com/smarie/python-pytest-cases>
#
# License: 3-clause BSD, <https://github.com/smarie/python-pytest-cases/blob/master/LICENSE>
import pytest

from pytest_cases import parametrize, fixture, fixture_ref


@pytest.fixture
def a():
    """Fixture whose value is the string 'a'."""
    letter = 'a'
    return letter


@fixture
@parametrize('second_letter', [fixture_ref('a'), 'o'])
def b(second_letter):
    """Return 'b' followed by `second_letter` (the `a` fixture's value, or 'o')."""
    prefix = 'b'
    return prefix + second_letter


@parametrize('arg', ['z',
                     fixture_ref(a),
                     fixture_ref(b),
                     'o'])
@pytest.mark.parametrize('bar', ['bar'])
def test_foo(arg, bar):
    """Combine fixture_ref parametrization with a plain pytest parametrize mark."""
    expected_args = ['z', 'a', 'ba', 'bo', 'o']
    assert bar == 'bar'
    assert arg in expected_args


def test_synthesis(module_results_dct):
    assert list(module_results_dct) == [
        'test_foo[z-bar]', 'test_foo[a-bar]', 'test_foo[b-a-bar]',