def nested_params(*params_set):
    """Build a ``parameterized.expand`` decorator over the cartesian product of the parameters.

    Args:
        params_set (list of parameters): One iterable of candidate values per position.
            The values are either all plain objects or all ``parameterized.param``
            instances; ``param`` instances may only carry keyword arguments.

    Returns:
        The decorator produced by ``parameterized.expand`` for the combined cases.

    Raises:
        TypeError: If ``param`` instances are mixed with plain objects.
        ValueError: If any ``param`` instance carries positional arguments.
    """
    candidates = []
    for group in params_set:
        candidates.extend(group)

    # Plain objects only: take the product directly and name the cases with ``_name_func``.
    if not any(isinstance(item, param) for item in candidates):
        combos = list(product(*params_set))
        return parameterized.expand(combos, name_func=_name_func)

    # At least one ``param`` is present from here on, so every entry has to be one.
    for item in candidates:
        if not isinstance(item, param):
            raise TypeError(
                "When using ``parameterized.param``, "
                "all the parameters have to be of the ``param`` type.")
    for item in candidates:
        if item.args:
            raise ValueError(
                "When using ``parameterized.param``, "
                "all the parameters have to be provided as keyword argument."
            )

    # Fold the groups together one at a time, merging keyword arguments pairwise.
    merged = [param()]
    for group in params_set:
        expanded = []
        for left in merged:
            for right in group:
                expanded.append(param(**left.kwargs, **right.kwargs))
        merged = expanded
    return parameterized.expand(merged)
Beispiel #2
0
    def _outer_wrapper(wrapped_function):
        """Expand ``wrapped_function`` into one test per matching file of the test ROM.

        The ROM path is read from the ``SKYTEMPLE_TEST_ROM`` environment variable.
        When that variable is unset or empty, or when no file matches, a stub that
        raises ``SkipTest`` is returned instead, marked with ``pytest.mark.romtest``.
        Otherwise the generated ``test_*`` functions are copied into the caller's
        frame locals and nothing is returned explicitly.

        NOTE(review): ``file_ext`` and ``path`` are not defined in this function;
        presumably they are arguments of the enclosing decorator factory - confirm.
        """
        import inspect
        import pytest
        from ndspy.rom import NintendoDSRom
        from unittest import SkipTest
        from parameterized import parameterized
        rom = None
        # Only try to open a ROM when the environment variable is set to a non-empty value.
        if 'SKYTEMPLE_TEST_ROM' in os.environ and os.environ[
                'SKYTEMPLE_TEST_ROM'] != '':
            rom = NintendoDSRom.fromFile(os.environ['SKYTEMPLE_TEST_ROM'])

        if rom:

            # Names each generated test "<function name>/<first argument>", i.e. after the file name.
            def dataset_name_func(testcase_func, _, param):
                return f'{testcase_func.__name__}/{param.args[0]}'

            # One (file name, result of ``rom.getFileByName``) pair per file that has the
            # requested extension and whose name starts with ``path``.
            files = [(x, rom.getFileByName(x))
                     for x in get_files_from_rom_with_extension(rom, file_ext)
                     if x.startswith(path)]

            if len(files) < 1:

                def no_files(*args, **kwargs):
                    raise SkipTest("No matching files were found in the ROM.")

                return pytest.mark.romtest(no_files)
            else:
                # Only inject ``pmd2_data`` when the test function declares a parameter of that name.
                spec = inspect.getfullargspec(wrapped_function)
                if "pmd2_data" in spec.args or "pmd2_data" in spec.kwonlyargs:
                    pmd2_data = get_ppmdu_config_for_rom(rom)

                    def pmd2datawrapper(*args, **kwargs):
                        return wrapped_function(*args,
                                                **kwargs,
                                                pmd2_data=pmd2_data)

                    # Keep the original name, since ``dataset_name_func`` builds test names from it.
                    pmd2datawrapper.__name__ = wrapped_function.__name__

                    parameterized.expand(files, name_func=dataset_name_func)(
                        pytest.mark.romtest(pmd2datawrapper))
                else:
                    parameterized.expand(files, name_func=dataset_name_func)(
                        pytest.mark.romtest(wrapped_function))
                # ``expand`` has added the generated tests to this function's locals, so they
                # have to be handed back to the caller's frame.
                # HACK: this relies on writing into another frame's ``f_locals``.
                frame_locals = inspect.currentframe(
                ).f_back.f_locals  # type: ignore
                for local_name, local in inspect.currentframe().f_locals.items(
                ):  # type: ignore
                    if local_name.startswith('test_'):
                        frame_locals[local_name] = local

        else:

            def no_tests(*args, **kwargs):
                raise SkipTest("No ROM file provided or ROM not found.")

            return pytest.mark.romtest(no_tests)
Beispiel #3
0
def _expand(generation: int):
    """Return a ``parameterized.expand`` decorator holding the privacy cases for ``generation``.

    Every case is a 3-tuple whose last item is an ``Event`` or ``None``. The first two
    items are presumably the expected and the initial privacy value - confirm against
    the consuming test.

    Args:
        generation: Negative values push the threshold year back by an extra 100 years
            per generation; zero and positive values use a single 100-year span.
    """
    if generation < 0:
        multiplier = abs(generation) + 1
    else:
        multiplier = 1
    threshold_year = datetime.now().year - 100 * multiplier
    date_under_threshold = Date(threshold_year + 1, 1, 1)
    date_over_threshold = Date(threshold_year - 1, 1, 1)

    def rows(expected_when_unset, make_event):
        # One row per initial value (None, True, False); each row gets its own event object.
        return [
            (expected_when_unset, None, make_event()),
            (True, True, make_event()),
            (False, False, make_event()),
        ]

    def birth(**kwargs):
        return Event('E0', Event.Type.BIRTH, **kwargs)

    # If there are no events for a person, their privacy does not change.
    cases = rows(True, lambda: None)
    # Deaths are special, and their existence prevents generation 0 from being private
    # even without a date.
    cases += rows(generation != 0, lambda: Event('E0', Event.Type.DEATH))
    # Regular events without dates do not affect privacy.
    cases += rows(True, birth)
    # Regular events with incomplete dates do not affect privacy.
    cases += rows(True, lambda: birth(date=Date()))
    # Regular events under the lifetime threshold do not affect privacy.
    cases += rows(True, lambda: birth(date=date_under_threshold))
    # Regular events over the lifetime threshold affect privacy.
    cases += rows(False, lambda: birth(date=date_over_threshold))
    return parameterized.expand(cases)
Beispiel #4
0
def _expand(generation: int):
    """Return a ``parameterized.expand`` decorator holding the privacy cases for ``generation``.

    Every case is a 3-tuple whose last item is an ``IdentifiableEvent`` or ``None``.
    The first two items are presumably the expected and the initial privacy value -
    confirm against the consuming test.

    Args:
        generation: Negative values push the threshold year back by an extra 100 years
            per generation; zero and positive values use a single 100-year span.
    """
    # Read the clock once so the threshold year and every "today" date agree, even when
    # the call straddles midnight or New Year.
    now = datetime.now()
    multiplier = abs(generation) + 1 if generation < 0 else 1
    threshold_year = now.year - 100 * multiplier
    date_under_threshold = Date(threshold_year + 1, 1, 1)
    date_range_start_under_threshold = DateRange(date_under_threshold)
    date_range_end_under_threshold = DateRange(None, date_under_threshold)
    date_over_threshold = Date(threshold_year - 1, 1, 1)
    date_range_start_over_threshold = DateRange(date_over_threshold)
    date_range_end_over_threshold = DateRange(None, date_over_threshold)

    def rows(expected_when_unset, make_event):
        # One row per initial value (None, True, False); each row gets its own event object.
        return [
            (expected_when_unset, None, make_event()),
            (True, True, make_event()),
            (False, False, make_event()),
        ]

    def birth(**kwargs):
        return IdentifiableEvent('E0', Event.Type.BIRTH, **kwargs)

    # If there are no events for a person, their privacy does not change.
    cases = rows(True, lambda: None)

    # Deaths and burials are special, and their existence prevents generation 0 from being
    # private even without having passed the usual threshold.
    for end_of_life in (Event.Type.DEATH, Event.Type.BURIAL):
        cases += [
            (generation != 0, None,
             IdentifiableEvent('E0', end_of_life, date=Date(now.year, now.month, now.day))),
            (generation != 0, None,
             IdentifiableEvent('E0', end_of_life, date=date_under_threshold)),
            (True, None,
             IdentifiableEvent('E0', end_of_life, date=date_range_start_under_threshold)),
            (generation != 0, None,
             IdentifiableEvent('E0', end_of_life, date=date_range_end_under_threshold)),
            (True, True, IdentifiableEvent('E0', end_of_life)),
            (False, False, IdentifiableEvent('E0', end_of_life)),
        ]

    # Regular events without dates do not affect privacy.
    cases += rows(True, birth)
    # Regular events with incomplete dates do not affect privacy.
    cases += rows(True, lambda: birth(date=Date()))
    # Regular events under the lifetime threshold do not affect privacy.
    for when in (
        date_under_threshold,
        date_range_start_under_threshold,
        date_range_end_under_threshold,
    ):
        cases += rows(True, lambda when=when: birth(date=when))
    # Regular events over the lifetime threshold affect privacy.
    for when in (
        date_over_threshold,
        date_range_start_over_threshold,
        date_range_end_over_threshold,
    ):
        cases += rows(False, lambda when=when: birth(date=when))
    return parameterized.expand(cases)
Beispiel #5
0
def nested_params(*params):
    """Return a ``parameterized.expand`` decorator over the cartesian product of ``params``.

    Generated test names consist of the test function's name followed by the string
    form of each argument, joined with underscores; a tuple argument contributes each
    of its items separately.
    """
    def _label(value):
        if isinstance(value, tuple):
            return "_".join(str(item) for item in value)
        return str(value)

    def _name_func(func, _, case):
        suffix = "_".join(_label(value) for value in case.args)
        return f'{func.__name__}_{suffix}'

    combos = list(product(*params))
    return parameterized.expand(combos, name_func=_name_func)
def exist_perms(**kwargs):
    """Return a ``parameterized.expand`` decorator over present/absent subsets of ``kwargs``.

    ``_perms_cycle`` supplies ``(name, flags)`` pairs, where ``flags`` is indexed by
    keyword name. Each generated case is ``(name, subset)``, with ``subset`` holding
    only the keyword arguments whose flag is truthy.

    Raises:
        IndexError: If called without keyword arguments (nothing to seed the cycle with).
    """
    # Keyword names are immutable strings, so a shallow list is already an independent copy.
    remaining = list(kwargs)
    seed = remaining.pop()

    cases = []
    for name_str, perm in _perms_cycle(seed, remaining, {}):
        selected = {key: value for key, value in kwargs.items() if perm[key]}
        cases.append((name_str, selected))

    return parameterized.expand(cases)
def exist_perms(**kwargs):
    """Return a ``parameterized.expand`` decorator over present/absent subsets of ``kwargs``.

    ``_perms_cycle`` supplies ``(name, flags)`` pairs, where ``flags`` is indexed by
    keyword name. Each generated case is ``(name, subset)``, with ``subset`` holding
    only the keyword arguments whose flag is truthy.

    Raises:
        IndexError: If called without keyword arguments (nothing to seed the cycle with).
    """
    # Keyword names are immutable strings, so a shallow list is already an independent copy.
    remaining = list(kwargs)
    seed = remaining.pop()

    cases = []
    for name_str, perm in _perms_cycle(seed, remaining, {}):
        selected = {key: value for key, value in kwargs.items() if perm[key]}
        cases.append((name_str, selected))

    return parameterized.expand(cases)
Beispiel #8
0
def parametric_suite(*args, **kwargs):
    """
    Decorator used for testing a range of different options for a particular
    ParametericTestGroup. If args is present, must only be the value '*',
    indicating running all available groups/parameters. Otherwise, use kwargs to set the options
    like so:
        arg=value will specify that option,
        arg='*' will vary over all default options,
        arg=iterable will iterate over the given options.
    Arguments that are not specified will have a reasonable default chosen.

    The ``run_by_default`` keyword (default False) is removed from kwargs before the
    test cases are built and is passed to ``_test_name`` to create the naming function."""
    run_by_default = kwargs.pop('run_by_default', False)
    test_cases = _test_suite(*args, **kwargs)
    # ``name_func`` is the supported keyword; ``testcase_func_name`` is only a deprecated alias.
    return parameterized.expand(test_cases,
                                name_func=_test_name(run_by_default))
Beispiel #9
0
def parametric_suite(*args, **kwargs):
    """
    Decorator used for testing a range of different options for a particular
    ParametericTestGroup. If args is present, must only be the value '*',
    indicating running all available groups/parameters. Otherwise, use kwargs to set the options
    like so:
        arg=value will specify that option,
        arg='*' will vary over all default options,
        arg=iterable will iterate over the given options.
    Arguments that are not specified will have a reasonable default chosen.

    The ``run_by_default`` keyword (default False) is removed from kwargs before the
    test cases are built and is passed to ``_test_name`` to create the naming function."""
    run_by_default = kwargs.pop('run_by_default', False)
    test_cases = _test_suite(*args, **kwargs)
    # ``name_func`` is the supported keyword; ``testcase_func_name`` is only a deprecated alias.
    return parameterized.expand(test_cases, name_func=_test_name(run_by_default))
Beispiel #10
0
def test_ms_domain(versions=None):
    """ Parameterize test case to apply ms opset(s) as extra_opset.

    Args:
        versions: Iterable of opset versions to test. When None, every version from 1
            up to and including ``_MAX_MS_OPSET_VERSION`` is used.

    Returns:
        The ``parameterized.expand`` decorator; each generated case receives a single
        opset id and is named ``<test function name>_<opset version>``.
    """
    def _custom_name_func(testcase_func, param_num, param):
        del param_num
        arg = param.args[0]
        return "%s_%s" % (testcase_func.__name__, arg.version)

    # Test all opset versions in ms domain if versions is not specified
    if versions is None:
        versions = range(1, _MAX_MS_OPSET_VERSION + 1)

    # Each case is a one-element list, so the test receives the opset id as its only argument.
    opsets = [
        [utils.make_opsetid(constants.MICROSOFT_DOMAIN, version)]
        for version in versions
    ]
    # ``name_func`` is the supported keyword; ``testcase_func_name`` is only a deprecated alias.
    return parameterized.expand(opsets, name_func=_custom_name_func)
def reset_and_delete():
    """
    Build the decorator that runs a test once with the delete operation and once
    with the reset operation, both of which should yield the same results.
    Each case passes the test a callable taking ``(self, task)``.
    NOTE: "parameterized" engages in call stack manipulation,
        so be careful when changing the application of the decorator.
        For example, receiving "func" as a parameter and passing it to
        "expand" doesn't work.
    """
    operation_names = ["delete", "reset"]
    operations = [
        (lambda self, task: self.tasks.delete(task=task),),
        (lambda self, task: self.tasks.reset(task=task),),
    ]

    def _name_for(func, num, _):
        # ``num`` is the case index; it selects the matching operation name as suffix.
        return "{}_{}".format(func.__name__, operation_names[int(num)])

    return parameterized.expand(operations, name_func=_name_for)
def use_all_backends(except_backends: Tuple[str, ...] = ()) -> Callable:
    """Decorate a test function so that it is run against every non-excluded backend.

    All backends in ``all_backends_list`` are used unless they are listed in
    ``except_backends``; this lets individual tests opt out of backends that lack
    a feature they need.

    Args:
        except_backends: Tuple[str], optional argument. Tuple of backend strings from
                         test_backend.py to exclude in testing.

    Returns:
        function that expands tests for each non-excluded backend.
    """
    selected_backends = []
    for backend in all_backends_list:
        if backend in except_backends:
            continue
        selected_backends.append(backend)
    # parameterized.expand() receives the list of backend strings and auto-generates one
    # test function per entry. For more information see https://github.com/wolever/parameterized
    return parameterized.expand(selected_backends)
Beispiel #13
0
# Configurations obtained through ``_load_config``; the two arguments are presumably the
# Hugging Face organisation and model name - confirm against ``_load_config``.
HF_LARGE = _load_config('facebook', 'wav2vec2-large')
HF_LARGE_LV60 = _load_config('facebook', 'wav2vec2-large-lv60')
HF_LARGE_XLSR_53 = _load_config('facebook', 'wav2vec2-large-xlsr-53')
HF_BASE_10K_VOXPOPULI = _load_config('facebook', 'wav2vec2-base-10k-voxpopuli')
# Fine-tuned models
HF_BASE_960H = _load_config('facebook', 'wav2vec2-base-960h')
HF_LARGE_960H = _load_config('facebook', 'wav2vec2-large-960h')
HF_LARGE_LV60_960H = _load_config('facebook', 'wav2vec2-large-960h-lv60')
HF_LARGE_LV60_SELF_960H = _load_config('facebook', 'wav2vec2-large-960h-lv60-self')
HF_LARGE_XLSR_DE = _load_config('facebook', 'wav2vec2-large-xlsr-53-german')

# Test decorators pairing each configuration with its factory function.
# Each case is ``(config, factory)``; test names are produced by ``_name_func``.
PRETRAIN_CONFIGS = parameterized.expand([
    (HF_BASE, wav2vec2_base),
    (HF_LARGE, wav2vec2_large),
    (HF_LARGE_LV60, wav2vec2_large_lv60k),
    (HF_LARGE_XLSR_53, wav2vec2_large_lv60k),
    (HF_BASE_10K_VOXPOPULI, wav2vec2_base),
], name_func=_name_func)
FINETUNE_CONFIGS = parameterized.expand([
    (HF_BASE_960H, wav2vec2_base),
    (HF_LARGE_960H, wav2vec2_large),
    (HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
    (HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k),
    (HF_LARGE_XLSR_DE, wav2vec2_large_lv60k),
], name_func=_name_func)


@skipIfNoModule('transformers')
class TestHFIntegration(TorchaudioTestCase):
    """Test the process of importing the models from Hugging Face Transformers
Beispiel #14
0
def _expand_person(generation: int):
    """Return a ``parameterized.expand`` decorator holding the privacy cases for a person.

    Every case is a 3-tuple whose last item is an ``Event`` or ``None``. The first two
    items are presumably the expected and the initial privacy value - confirm against
    the consuming test.

    Args:
        generation: Negative values push the threshold year back by an extra lifetime
            (125 years) per generation; zero and positive values use a single lifetime.
    """
    # Read the clock once so the threshold year and the "today" date agree, even when
    # the call straddles midnight or New Year.
    now = datetime.now()
    lifetime_threshold = 125
    multiplier = abs(generation) + 1 if generation < 0 else 1
    lifetime_threshold_year = now.year - lifetime_threshold * multiplier
    date_under_lifetime_threshold = Date(lifetime_threshold_year + 1, 1, 1)
    date_range_start_under_lifetime_threshold = DateRange(
        date_under_lifetime_threshold)
    date_range_end_under_lifetime_threshold = DateRange(
        None, date_under_lifetime_threshold)
    date_over_lifetime_threshold = Date(lifetime_threshold_year - 1, 1, 1)
    date_range_start_over_lifetime_threshold = DateRange(
        date_over_lifetime_threshold)
    date_range_end_over_lifetime_threshold = DateRange(
        None, date_over_lifetime_threshold)

    def rows(expected_when_unset, make_event):
        # One row per initial value (None, True, False); each row gets its own event object.
        return [
            (expected_when_unset, None, make_event()),
            (True, True, make_event()),
            (False, False, make_event()),
        ]

    def birth(**kwargs):
        return Event(Birth(), **kwargs)

    # If there are no events for a person, they are private.
    cases = rows(True, lambda: None)

    # Deaths and other end-of-life events are special, but only for the person whose
    # privacy is being checked:
    # - If they're present without dates, the person isn't private.
    # - If they're present and their dates or date ranges' end dates are in the past,
    #   the person isn't private.
    cases += [
        (generation != 0, None,
         Event(Death(), date=Date(now.year, now.month, now.day))),
        (generation != 0, None,
         Event(Death(), date=date_under_lifetime_threshold)),
        (True, None,
         Event(Death(), date=date_range_start_under_lifetime_threshold)),
        (generation != 0, None,
         Event(Death(), date=date_range_end_under_lifetime_threshold)),
        (False, None, Event(Death(), date=date_over_lifetime_threshold)),
        (True, None,
         Event(Death(), date=date_range_start_over_lifetime_threshold)),
        (False, None,
         Event(Death(), date=date_range_end_over_lifetime_threshold)),
        (True, True, Event(Death())),
        (False, False, Event(Death())),
        (generation != 0, None, Event(Death())),
    ]

    # Regular events without dates do not affect privacy.
    cases += rows(True, birth)
    # Regular events with incomplete dates do not affect privacy.
    cases += rows(True, lambda: birth(date=Date()))
    # Regular events under the lifetime threshold do not affect privacy.
    for when in (
        date_under_lifetime_threshold,
        date_range_start_under_lifetime_threshold,
        date_range_end_under_lifetime_threshold,
    ):
        cases += rows(True, lambda when=when: birth(date=when))
    # Regular events over the lifetime threshold affect privacy. A range that only
    # starts over the threshold keeps the expectation for the undecided row at True.
    for expected, when in (
        (False, date_over_lifetime_threshold),
        (True, date_range_start_over_lifetime_threshold),
        (False, date_range_end_over_lifetime_threshold),
    ):
        cases += rows(expected, lambda when=when: birth(date=when))
    return parameterized.expand(cases)
Beispiel #15
0
 def expand_parameterized_test_export(*args, **kwargs):
     """Forward to ``parameterized.expand``, naming tests with ``get_export_test_name`` by default.

     A ``name_func`` supplied by the caller is kept as is.
     """
     kwargs.setdefault("name_func", get_export_test_name)
     return parameterized.expand(*args, **kwargs)
Beispiel #16
0
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 10):
    import torch.ao.quantization as tq
else:
    import torch.quantization as tq


def _name_func(testcase_func, i, param):
    return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"


# Test decorator that runs a test once per model factory function.
# Each case passes a single factory; test names are produced by ``_name_func``.
factory_funcs = parameterized.expand([
    (wav2vec2_base, ),
    (wav2vec2_large, ),
    (wav2vec2_large_lv60k, ),
    (hubert_base, ),
    (hubert_large, ),
    (hubert_xlarge, ),
], name_func=_name_func)


class TestWav2Vec2Model(TorchaudioTestCase):
    def _smoke_test(self, model, device, dtype):
        model = model.to(device=device, dtype=dtype)
        model = model.eval()

        torch.manual_seed(0)
        batch_size, num_frames = 3, 1024

        waveforms = torch.randn(
            batch_size, num_frames, device=device, dtype=dtype)
Beispiel #17
0

# Classifier with an int64 dtype, -1 as its missing value, no inputs and a zero-length window.
# NOTE(review): presumably a minimal fixture for the tests in this module - confirm.
class OtherC(Classifier):
    dtype = int64_dtype
    missing_value = -1
    inputs = ()
    window_length = 0


# Filter with no inputs and a zero-length window.
# NOTE(review): presumably a minimal fixture used as a mask in the tests - confirm.
class Mask(Filter):
    inputs = ()
    window_length = 0


# Test decorator that runs a test once per factor dtype; each case is ``(label, dtype)``.
for_each_factor_dtype = parameterized.expand([
    ('datetime64[ns]', datetime64ns_dtype),
    ('float', float64_dtype),
])


class FactorTestCase(BasePipelineTestCase):
    def init_instance_fixtures(self):
        super(FactorTestCase, self).init_instance_fixtures()
        self.f = F()

    def test_bad_input(self):
        with self.assertRaises(UnknownRankMethod):
            self.f.rank("not a real rank method")

    @parameter_space(method_name=['isnan', 'notnan', 'isfinite'])
    def test_float64_only_ops(self, method_name):
        class NotFloat(Factor):
def true_false_perms(*all_elems_tuple):
    """Return a ``parameterized.expand`` decorator over the cases built by ``_perms_cycle``.

    A deep copy of the given elements is used as the working list: its last element
    seeds ``_perms_cycle`` and the rest is passed along as the remaining elements,
    together with an empty starting dict.

    Raises:
        IndexError: If called without any element.
    """
    remaining = copy.deepcopy(list(all_elems_tuple))
    seed = remaining.pop()
    cases = _perms_cycle(seed, remaining, {})
    return parameterized.expand(cases)
Beispiel #19
0
def parameterize(*params):
    """Return a ``parameterized.expand`` decorator over the cartesian product of ``params``.

    Test names are produced by the module-level ``name_func``.
    """
    combinations = list(itertools.product(*params))
    return parameterized.expand(combinations, name_func=name_func)
# Configurations obtained through ``_load_config``, one per named model.
HUBERT_BASE = _load_config('hubert_base_ls960')
HUBERT_LARGE_LL60K = _load_config('hubert_large_ll60k')
HUBERT_XLARGE_LL60K = _load_config('hubert_xtralarge_ll60k')
# Fine-tuned models
WAV2VEC2_BASE_960H = _load_config('wav2vec_small_960h')
WAV2VEC2_LARGE_960H = _load_config('wav2vec_large_960h')
WAV2VEC2_LARGE_LV60K_960H = _load_config('wav2vec_large_lv60k_960h')
WAV2VEC2_LARGE_LV60K_SELF_960H = _load_config('wav2vec_large_lv60k_self_960h')
HUBERT_LARGE = _load_config('hubert_large_ll60k_finetune_ls960')
HUBERT_XLARGE = _load_config('hubert_xtralarge_ll60k_finetune_ls960')

# Test decorators pairing each configuration with its factory function.
# Each case is ``(config, factory)``; test names are produced by ``_name_func``.
WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand([
    (WAV2VEC2_BASE, wav2vec2_base),
    (WAV2VEC2_LARGE, wav2vec2_large),
    (WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k),
    (WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k),
],
                                                    name_func=_name_func)
HUBERT_PRETRAINING_CONFIGS = parameterized.expand([
    (HUBERT_BASE, hubert_base),
    (HUBERT_LARGE_LL60K, hubert_large),
    (HUBERT_XLARGE_LL60K, hubert_xlarge),
],
                                                  name_func=_name_func)
ALL_PRETRAINING_CONFIGS = parameterized.expand([
    (WAV2VEC2_BASE, wav2vec2_base),
    (WAV2VEC2_LARGE, wav2vec2_large),
    (WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k),
    (WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k),
    (HUBERT_BASE, hubert_base),
Beispiel #21
0
def deterministic_expand(params):
    """Return a ``parameterized.expand`` decorator built from parameter-producing callables.

    ``params`` is a list of lambdas, each producing a tuple of unique parameters for
    the test. The torch and NumPy random seeds are both set to 0 before any of the
    lambdas is called.
    """
    torch.manual_seed(0)
    np.random.seed(0)
    cases = []
    for make_case in params:
        cases.append(make_case())
    return parameterized.expand(cases)
def true_false_perms(*all_elems_tuple):
    """Return a ``parameterized.expand`` decorator over the cases built by ``_perms_cycle``.

    A deep copy of the given elements is used as the working list: its last element
    seeds ``_perms_cycle`` and the rest is passed along as the remaining elements,
    together with an empty starting dict.

    Raises:
        IndexError: If called without any element.
    """
    remaining = copy.deepcopy(list(all_elems_tuple))
    seed = remaining.pop()
    cases = _perms_cycle(seed, remaining, {})
    return parameterized.expand(cases)
Beispiel #23
0
from graphql_compiler.tests import test_backend
from graphql_compiler.tests.test_helpers import generate_schema, generate_schema_graph

from ..test_helpers import SCHEMA_TEXT, compare_ignoring_whitespace, get_schema
from .integration_backend_config import MATCH_BACKENDS, SQL_BACKENDS
from .integration_test_helpers import (compile_and_run_match_query,
                                       compile_and_run_sql_query,
                                       sort_db_results)

# Store the test parameterization for running against all backends. Individual tests can
# customize the list of backends to test against by using the full
# @parameterized.expand([...]) decorator instead.
all_backends = parameterized.expand([
    test_backend.ORIENTDB,
    test_backend.POSTGRES,
    test_backend.MARIADB,
    test_backend.MYSQL,
    test_backend.SQLITE,
    test_backend.MSSQL,
])

# Store the typical fixtures required for integration tests.
# Individual tests can supply the full @pytest.mark.usefixtures to override if necessary.
integration_fixtures = pytest.mark.usefixtures(
    'integration_graph_client',
    'sql_integration_data',
)


# The following test class uses several fixtures adding members that pylint
# does not recognize
# pylint: disable=no-member
Beispiel #24
0
from torchaudio.models.wav2vec2 import (
    wav2vec2_base,
    wav2vec2_large,
    wav2vec2_large_lv60k,
)
from torchaudio_unittest.common_utils import (
    TorchaudioTestCase,
    skipIfNoQengine,
    skipIfNoCuda,
)
from parameterized import parameterized

# Test decorator that runs a test once per wav2vec2 factory function (one factory per case).
factory_funcs = parameterized.expand([
    (wav2vec2_base, ),
    (wav2vec2_large, ),
    (wav2vec2_large_lv60k, ),
])


class TestWav2Vec2Model(TorchaudioTestCase):
    def _smoke_test(self, device, dtype):
        model = wav2vec2_base(num_out=32)
        model = model.to(device=device, dtype=dtype)
        model = model.eval()

        torch.manual_seed(0)
        batch_size, num_frames = 3, 1024

        waveforms = torch.randn(batch_size,
                                num_frames,
Beispiel #25
0
def argsprod(*args):
    """Return a ``parameterized.expand`` decorator over the cartesian product of ``args``.

    Every combination is passed to the test as one tuple of positional arguments.
    """
    combinations = list(map(tuple, itertools.product(*args)))
    return parameterized.expand(combinations)