def test_override_default_args():
    """Check ``override_default_kwargs`` on a function and on a class.

    Covers: an empty override keeps the existing default, an overridden
    default is used when the argument is omitted, an explicitly passed value
    still wins, the wrapped original keeps its own defaults, and overriding
    an argument name that does not exist raises ``ValueError``.
    """

    def fn(a, b=2):
        return a, b

    # --- plain function ---------------------------------------------------
    assert fn(1) == (1, 2)

    untouched_fn = override_default_kwargs(fn, {})
    assert untouched_fn(1) == (1, 2)

    patched_fn = override_default_kwargs(fn, {"b": 4})
    assert patched_fn(1) == (1, 4)
    assert patched_fn(1, 3) == (1, 3)

    # original function unchanged
    assert fn(1) == (1, 2)

    # --- class (defaults live on __init__) --------------------------------
    class A(object):
        def __init__(self, a, b=2):
            self.a = a
            self.b = b

        def get_values(self):
            return self.a, self.b

    assert A(1).get_values() == (1, 2)

    PatchedA = override_default_kwargs(A, {"b": 4})
    assert PatchedA(1).get_values() == (1, 4)

    # original class unchanged
    assert A(1).get_values() == (1, 2)
    assert PatchedA(1, 3).get_values() == (1, 3)

    # 'c' is not a parameter of A.__init__
    with pytest.raises(ValueError):
        override_default_kwargs(A, {"c": 4})
def get(self):
    """Get the dataloader.

    Loads the object referenced by ``self.defined_as`` and checks that it
    inherits from ``kipoi.data.BaseDataLoader``. When ``self.default_args``
    is non-empty, the object is passed through ``override_default_kwargs``
    and the ``example`` of every overridden argument that already had one
    is replaced by the new default value.

    Returns:
        The loaded object, or the result of ``override_default_kwargs``
        when ``self.default_args`` is non-empty.

    Raises:
        ValueError: If the loaded object does not inherit from
            ``kipoi.data.BaseDataLoader``.
    """
    # Function-scope imports are kept from the original
    # (presumably to avoid an import cycle -- TODO confirm).
    from kipoi.data import BaseDataLoader
    from kipoi_utils.external.related.fields import UNSPECIFIED

    obj = load_obj(self.defined_as)

    # check that it inherits from BaseDataLoader
    if not inherits_from(obj, BaseDataLoader):
        raise ValueError(f"Dataloader: {self.defined_as} doesn't inherit from kipoi.data.BaseDataLoader")

    # nothing to override -> hand back the loaded object as-is
    if not self.default_args:
        return obj

    # override the default arguments
    obj = override_default_kwargs(obj, self.default_args)

    # Also override the values in the example in case they were previously
    # specified; arguments whose example is UNSPECIFIED are left alone.
    # TODO: How to modify this code with KipoiDataLoaderImport in mind?
    for k, v in self.default_args.items():
        arg_spec = obj.args[k]
        if not isinstance(arg_spec.example, UNSPECIFIED):
            arg_spec.example = v
    return obj
def get(self):
    """Get the dataloader.

    Loads the object referenced by ``self.defined_as`` and checks that it
    inherits from ``kipoi.data.BaseDataLoader``. When ``self.default_args``
    is non-empty, the object is passed through ``override_default_kwargs``
    and the ``example`` of every overridden argument that already had one
    is replaced by the new default value.

    Returns:
        The loaded object, or the result of ``override_default_kwargs``
        when ``self.default_args`` is non-empty.

    Raises:
        ValueError: If the loaded object does not inherit from
            ``kipoi.data.BaseDataLoader``.
    """
    # Function-scope import is kept from the original
    # (presumably to avoid an import cycle -- TODO confirm).
    from kipoi.data import BaseDataLoader

    obj = load_obj(self.defined_as)

    # check that it inherits from BaseDataLoader
    if not inherits_from(obj, BaseDataLoader):
        raise ValueError(
            "Dataloader: {} doesn't inherit from kipoi.data.BaseDataLoader".
            format(self.defined_as))

    # nothing to override -> hand back the loaded object as-is
    if not self.default_args:
        return obj

    # override the default arguments
    obj = override_default_kwargs(obj, self.default_args)

    # Also override the values in the example in case they were previously
    # specified; arguments whose example is UNSPECIFIED are left alone.
    # UNSPECIFIED is resolved from the enclosing module, as in the original.
    for k, v in self.default_args.items():
        arg_spec = obj.args[k]
        if not isinstance(arg_spec.example, UNSPECIFIED):
            arg_spec.example = v
    return obj
def test_override_multiple_default_args():
    """Check ``override_default_kwargs`` with several positional/keyword args.

    Covers a function and a class whose ``__init__`` has three required and
    two defaulted parameters: overriding an existing default, supplying
    defaults for otherwise-required parameters, the error raised when
    required parameters are still missing, and the exact ``ValueError``
    message for an unknown argument name.
    """
    missing_in_fn = "fn() missing 2 required positional arguments: 'b' and 'c'"
    missing_in_init = "__init__() missing 2 required positional arguments: 'b' and 'c'"

    def fn(a, b, c, d=4, e=5):
        return a, b, c, d, e

    # --- plain function ---------------------------------------------------
    assert fn(1, 2, 3) == (1, 2, 3, 4, 5)

    untouched_fn = override_default_kwargs(fn, {})
    assert untouched_fn(1, 2, 3, 4, 5) == (1, 2, 3, 4, 5)

    patched_fn = override_default_kwargs(fn, {"d": 20})
    assert patched_fn(1, 2, 3) == (1, 2, 3, 20, 5)
    assert patched_fn(1, 2, 3, 10, 20) == (1, 2, 3, 10, 20)

    # original function unchanged
    assert fn(1, 2, 3) == (1, 2, 3, 4, 5)

    # 'b' and 'c' received no default, so they are still required
    with pytest.raises(TypeError) as excinfo:
        patched_fn(1)
    assert missing_in_fn in str(excinfo.value)

    # --- class (defaults live on __init__) --------------------------------
    class A(object):
        def __init__(self, a, b, c, d=4, e=5):
            self.a = a
            self.b = b
            self.c = c
            self.d = d
            self.e = e

        def get_values(self):
            return self.a, self.b, self.c, self.d, self.e

    with pytest.raises(TypeError) as excinfo:
        A(1).get_values()
    assert missing_in_init in str(excinfo.value)

    PatchedA = override_default_kwargs(A, {"b": 10, "c": 20})
    assert PatchedA(1).get_values() == (1, 10, 20, 4, 5)

    # original class unchanged
    with pytest.raises(TypeError) as excinfo:
        A(1).get_values()
    assert missing_in_init in str(excinfo.value)

    assert PatchedA(1, 30, 40).get_values() == (1, 30, 40, 4, 5)

    # 'f' is not a parameter of A.__init__; the full message is pinned
    with pytest.raises(ValueError) as err:
        override_default_kwargs(A, {"f": 8})
    expected_msg = (
        "argument 'f' not specified in function/class.__init__ "
        "<class 'kipoi_utils.utils.Overriddentest_override_multiple_default_args.<locals>.A'> "
        "with args: ['a', 'b', 'c', 'd', 'e']"
    )
    assert str(err.value) == expected_msg