def multiple_parameters(*param_lists):
  """Parameterizes a test over every combination of several parameter lists.

  Each entry in each list is normalized to a tuple: tuples are kept as-is (so
  they can supply several positional arguments at once), anything else is
  wrapped in a one-element tuple. Combinations are formed by picking one entry
  from each list in argument order and concatenating the picked tuples; the
  first list varies slowest.

  Args:
    *param_lists: One or more iterables of parameter values. At least one must
      be given, since the first list seeds the combinations.

  Returns:
    The decorator returned by ``parameterized.parameters`` for all of the
    combined argument tuples.
  """

  def as_tuple(value):
    # A bare value becomes a single positional argument.
    return value if isinstance(value, tuple) else (value,)

  normalized = [[as_tuple(value) for value in values] for values in param_lists]

  combos = normalized[0]
  for values in normalized[1:]:
    expanded = []
    for prefix in combos:
      for suffix in values:
        expanded.append((*prefix, *suffix))
    combos = expanded

  return parameterized.parameters(*combos)
def wrapper(fn):
  """Decorates ``fn`` so each test case is supplied as named arguments.

  Every entry of ``testcases`` (read from the enclosing scope) is converted to
  a tuple and zipped with ``fn``'s positional parameter names, excluding the
  leading ``self``. The resulting dicts are handed to
  ``parameterized.parameters``, which passes dict test cases as keyword
  arguments.

  Args:
    fn: The test method to decorate. Its first positional parameter must be
      named ``self``.

  Returns:
    ``fn`` decorated by ``parameterized.parameters``.

  Raises:
    ValueError: If ``fn`` has no positional parameters or the first one is not
      ``self``, or if a test case does not provide exactly one value per
      remaining positional parameter.
  """
  # `inspect.getargspec` was removed in Python 3.11; `getfullargspec().args`
  # yields the same list of positional parameter names (and, unlike
  # getargspec, does not reject annotated or keyword-only signatures).
  arg_names = inspect.getfullargspec(fn).args
  # Guard the empty case so a zero-argument `fn` gets the intended
  # ValueError rather than an IndexError.
  if not arg_names or arg_names[0] != 'self':
    raise ValueError(
        'First argument to test is expected to be "self", but is {}'.format(
            arg_names[0] if arg_names else None))
  arg_names = arg_names[1:]

  def to_arg_dict(testcase):
    """Maps one test case's values onto ``fn``'s parameter names."""
    testcase = tuple(testcase)
    if len(testcase) != len(arg_names):
      raise ValueError(
          'The number of arguments to parameterized test do not match the '
          'number of expected arguments: {} != {}, arguments: {}, names: {}'.
          format(len(testcase), len(arg_names), testcase, arg_names))
    return dict(zip(arg_names, testcase))

  testcases_with_names = [to_arg_dict(testcase) for testcase in testcases]
  return parameterized.parameters(*testcases_with_names)(fn)
import collections import copy import pickle import threading from absl.testing import absltest from absl.testing import parameterized import cloudpickle import dill from haiku._src import data_structures import jax import tree frozendict = data_structures.frozendict FlatMap = data_structures.FlatMap all_picklers = parameterized.parameters(cloudpickle, dill, pickle) class StackTest(absltest.TestCase): cls = data_structures.Stack def test_len(self): s = self.cls() self.assertEmpty(s) for i in range(10): self.assertLen(s, i) s.push(None) for i in range(10): self.assertLen(s, 10 - i) s.pop()