def test_override_default_args():
    """`override_default_kwargs` rebinds default kwargs on a new callable.

    Covered for both a plain function and a class: the returned object uses
    the new defaults, explicit arguments still take precedence, the original
    object keeps its own defaults, and an unknown keyword raises ValueError.
    """
    # --- plain function ----------------------------------------------------
    def fn(a, b=2):
        return a, b

    assert fn(1) == (1, 2)
    # an empty override mapping leaves the defaults as they were
    assert override_default_kwargs(fn, {})(1) == (1, 2)

    fn_b4 = override_default_kwargs(fn, {"b": 4})
    assert fn_b4(1) == (1, 4)
    # an explicitly passed argument still wins over the overridden default
    assert fn_b4(1, 3) == (1, 3)
    # the original function is unchanged
    assert fn(1) == (1, 2)

    # --- class ---------------------------------------------------------------
    class A(object):
        def __init__(self, a, b=2):
            self.a = a
            self.b = b

        def get_values(self):
            return self.a, self.b

    assert A(1).get_values() == (1, 2)

    B = override_default_kwargs(A, {"b": 4})
    assert B(1).get_values() == (1, 4)
    # the original class is unchanged
    assert A(1).get_values() == (1, 2)
    assert B(1, 3).get_values() == (1, 3)

    # `c` is not an argument of A.__init__, so the override is rejected
    with pytest.raises(ValueError):
        override_default_kwargs(A, {"c": 4})
def test_output_schape():
    """Check the output schema of SeqIntervalDl variants with overridden defaults.

    Each override must produce its own `inputs` / `targets` schema, while the
    base dataloader keeps its default schema throughout.
    """
    base = deepcopy(SeqIntervalDl)

    def inputs_shape(dl):
        # shape of the `inputs` entry of the dataloader's output schema
        return dl.get_output_schema().inputs.shape

    def overridden(**overrides):
        # a variant of the base dataloader with the given default kwargs
        return override_default_kwargs(base, overrides)

    # defaults: no resizing (length None) and 4 columns
    assert inputs_shape(base) == (None, 4)

    assert inputs_shape(overridden(auto_resize_len=100)) == (100, 4)
    # the base dataloader is not affected by the override
    assert inputs_shape(base) == (None, 4)

    # dummy_axis / alphabet_axis decide where the extra axis and alphabet go
    assert inputs_shape(
        overridden(auto_resize_len=100, dummy_axis=1, alphabet_axis=2)
    ) == (100, 1, 4)
    assert inputs_shape(overridden(auto_resize_len=100, dummy_axis=2)) == (100, 4, 1)
    assert inputs_shape(base) == (None, 4)

    # a 5-letter alphabet widens the alphabet axis to 5
    assert inputs_shape(overridden(auto_resize_len=100, alphabet="ACGTD")) == (100, 5)

    assert inputs_shape(
        overridden(auto_resize_len=160, dummy_axis=2, alphabet_axis=0)
    ) == (4, 160, 1)

    resized = overridden(auto_resize_len=160, dummy_axis=2, alphabet_axis=1)
    assert inputs_shape(resized) == (160, 4, 1)
    assert resized.get_output_schema().targets.shape == (None, )

    # targets can be removed from the schema entirely
    assert overridden(ignore_targets=True).get_output_schema().targets is None

    # after all of the above, the base dataloader is still intact
    assert inputs_shape(base) == (None, 4)
    assert base.get_output_schema().targets.shape == (None, )
def get(self):
    """Get the dataloader.

    Loads the object named by ``self.defined_as`` and checks that it inherits
    from ``kipoi.data.BaseDataLoader``. When ``self.default_args`` is non-empty:

    - the object is wrapped with ``override_default_kwargs`` so those values
      become its new default keyword arguments;
    - every argument description in ``obj.args`` that already had an example
      gets its example replaced by the new default.

    Returns:
        The dataloader object, wrapped by ``override_default_kwargs`` when
        ``self.default_args`` is non-empty.

    Raises:
        ValueError: if the loaded object does not inherit from
            ``kipoi.data.BaseDataLoader``.
    """
    from kipoi.data import BaseDataLoader
    from copy import deepcopy
    obj = load_obj(self.defined_as)

    # check that it inherits from BaseDataLoader
    if not inherits_from(obj, BaseDataLoader):
        raise ValueError("Dataloader: {} doen't inherit from kipoi.data.BaseDataLoader".format(self.defined_as))

    if self.default_args:
        # override the default arguments (the original object stays unchanged)
        obj = override_default_kwargs(obj, self.default_args)

        # `obj.args` may still be the very mapping held by the original loaded
        # object (e.g. an inherited class attribute or a shallow copy), so
        # editing its entries in place would leak the new examples into every
        # other user of that dataloader. Work on a private deep copy instead.
        obj.args = deepcopy(obj.args)

        # override also the values in the example in case
        # they were previously specified
        for k, v in six.iteritems(self.default_args):
            if not isinstance(obj.args[k].example, UNSPECIFIED):
                obj.args[k].example = v

    return obj