コード例 #1
0
def test_override_default_args():
    """Check that override_default_kwargs swaps keyword defaults.

    Covers both a plain function and a class constructor. Explicitly passed
    arguments must still win over the new defaults, the wrapped original must
    keep its own defaults, and overriding an argument the class constructor
    does not accept must raise ValueError.
    """
    def fn(a, b=2):
        return a, b

    # Baseline behaviour; an empty override leaves the defaults alone.
    assert fn(1) == (1, 2)
    assert override_default_kwargs(fn, {})(1) == (1, 2)

    patched_fn = override_default_kwargs(fn, {"b": 4})
    assert patched_fn(1) == (1, 4)
    # An explicit positional value still takes precedence.
    assert patched_fn(1, 3) == (1, 3)
    # The original function is unchanged.
    assert fn(1) == (1, 2)

    class A(object):
        def __init__(self, a, b=2):
            self.a, self.b = a, b

        def get_values(self):
            return self.a, self.b

    assert A(1).get_values() == (1, 2)
    B = override_default_kwargs(A, {"b": 4})
    assert B(1).get_values() == (1, 4)
    # The original class is unchanged.
    assert A(1).get_values() == (1, 2)
    assert B(1, 3).get_values() == (1, 3)

    # Overriding a name the constructor does not accept is rejected.
    with pytest.raises(ValueError):
        override_default_kwargs(A, {"c": 4})
コード例 #2
0
ファイル: kipoimodeldescription.py プロジェクト: kipoi/kipoi
    def get(self):
        """Load the dataloader referenced by ``self.defined_as``.

        The object is loaded with ``load_obj`` and must inherit from
        ``kipoi.data.BaseDataLoader``. When ``self.default_args`` is non-empty,
        the object is replaced by ``override_default_kwargs(obj, default_args)``
        so those values become its new keyword defaults. Then, for every
        overridden argument whose ``obj.args[k].example`` is not an
        ``UNSPECIFIED`` instance, that example is set to the new default.

        Returns:
            The loaded dataloader, with overridden defaults when
            ``self.default_args`` is non-empty.

        Raises:
            ValueError: if the loaded object does not inherit from
                ``kipoi.data.BaseDataLoader``.
        """
        from kipoi.data import BaseDataLoader
        from kipoi_utils.external.related.fields import UNSPECIFIED

        obj = load_obj(self.defined_as)

        # check that it inherits from BaseDataLoader
        if not inherits_from(obj, BaseDataLoader):
            raise ValueError(f"Dataloader: {self.defined_as} doen't inherit from kipoi.data.BaseDataLoader")

        # Nothing to override. Keeping the example update behind this guard
        # also avoids calling ``.items()`` on a ``None`` default_args.
        if not self.default_args:
            return obj

        # override the default arguments
        obj = override_default_kwargs(obj, self.default_args)

        # Override also the values in the example in case they were
        # previously specified.
        # NOTE(review): if override_default_kwargs returns a class that
        # inherits ``args`` from the original, these assignments also change
        # the original class's argument descriptions - confirm intended.
        # NOTE(review): an example counts as "specified" unless it is an
        # instance of UNSPECIFIED; if UNSPECIFIED is stored as the bare class
        # rather than an instance, this check is always True - confirm.
        # TODO: How to modify this code with KipoiDataLoaderImport in mind?
        for k, v in self.default_args.items():
            if not isinstance(obj.args[k].example, UNSPECIFIED):
                obj.args[k].example = v

        return obj
コード例 #3
0
    def get(self):
        """Load the dataloader referenced by ``self.defined_as``.

        The object is loaded with ``load_obj`` and must inherit from
        ``kipoi.data.BaseDataLoader``. When ``self.default_args`` is non-empty,
        the object is replaced by ``override_default_kwargs(obj, default_args)``
        so those values become its new keyword defaults. Then, for every
        overridden argument whose ``obj.args[k].example`` is not an
        ``UNSPECIFIED`` instance, that example is set to the new default.

        Returns:
            The loaded dataloader, with overridden defaults when
            ``self.default_args`` is non-empty.

        Raises:
            ValueError: if the loaded object does not inherit from
                ``kipoi.data.BaseDataLoader``.
        """
        from kipoi.data import BaseDataLoader

        obj = load_obj(self.defined_as)

        # check that it inherits from BaseDataLoader
        if not inherits_from(obj, BaseDataLoader):
            raise ValueError(
                "Dataloader: {} doen't inherit from kipoi.data.BaseDataLoader"
                .format(self.defined_as))

        # Nothing to override. Keeping the example update behind this guard
        # also avoids calling ``.items()`` on a ``None`` default_args.
        if not self.default_args:
            return obj

        # override the default arguments
        obj = override_default_kwargs(obj, self.default_args)

        # Override also the values in the example in case they were
        # previously specified.
        # NOTE(review): if override_default_kwargs returns a class that
        # inherits ``args`` from the original, these assignments also change
        # the original class's argument descriptions - confirm intended.
        for k, v in self.default_args.items():
            if not isinstance(obj.args[k].example, UNSPECIFIED):
                obj.args[k].example = v

        return obj
コード例 #4
0
def test_override_multiple_default_args():
    """Check override_default_kwargs with several parameters and defaults.

    Overrides may target parameters that already had a default (``d``) or
    that were originally required (``b`` and ``c``). The wrapped original
    must keep its signature, and an unknown argument name must raise a
    ValueError with a descriptive message.
    """
    def fn(a, b, c, d=4, e=5):
        return a, b, c, d, e

    # Baseline; an empty override is a no-op.
    assert fn(1, 2, 3) == (1, 2, 3, 4, 5)
    assert override_default_kwargs(fn, {})(1, 2, 3, 4, 5) == (1, 2, 3, 4, 5)

    patched_fn = override_default_kwargs(fn, {"d": 20})
    assert patched_fn(1, 2, 3) == (1, 2, 3, 20, 5)
    assert patched_fn(1, 2, 3, 10, 20) == (1, 2, 3, 10, 20)
    # The original function is unchanged.
    assert fn(1, 2, 3) == (1, 2, 3, 4, 5)

    # Required parameters that were not overridden are still required.
    with pytest.raises(TypeError) as exc_info:
        patched_fn(1)
    assert "fn() missing 2 required positional arguments: 'b' and 'c'" in str(
        exc_info.value)

    class A(object):
        def __init__(self, a, b, c, d=4, e=5):
            self.a, self.b, self.c, self.d, self.e = a, b, c, d, e

        def get_values(self):
            return self.a, self.b, self.c, self.d, self.e

    def check_original_requires_b_and_c():
        # The un-overridden class must still insist on b and c.
        with pytest.raises(TypeError) as init_exc:
            A(1).get_values()
        assert "__init__() missing 2 required positional arguments: 'b' and 'c'" in str(
            init_exc.value)

    check_original_requires_b_and_c()

    B = override_default_kwargs(A, {"b": 10, "c": 20})
    assert B(1).get_values() == (1, 10, 20, 4, 5)
    # The original class is unchanged.
    check_original_requires_b_and_c()
    assert B(1, 30, 40).get_values() == (1, 30, 40, 4, 5)

    # Overriding an argument the constructor does not accept is rejected.
    with pytest.raises(ValueError) as bad_arg:
        override_default_kwargs(A, {"f": 8})
    expected_message = "argument 'f' not specified in function/class.__init__ <class 'kipoi_utils.utils.Overriddentest_override_multiple_default_args.<locals>.A'> with args: ['a', 'b', 'c', 'd', 'e']"
    assert str(bad_arg.value) == expected_message