예제 #1
0
def test_override_default_args():
    """override_default_kwargs replaces keyword defaults of functions and
    classes, returns a new callable, and leaves the original untouched."""

    # --- plain functions ---------------------------------------------------
    def pair(a, b=2):
        return a, b

    assert pair(1) == (1, 2)
    # An empty override mapping keeps the existing default.
    assert override_default_kwargs(pair, {})(1) == (1, 2)

    patched = override_default_kwargs(pair, {"b": 4})
    assert patched(1) == (1, 4)
    # An explicitly passed argument still takes precedence over the new default.
    assert patched(1, 3) == (1, 3)
    # The original function keeps its own default.
    assert pair(1) == (1, 2)

    # --- classes (defaults of __init__) -------------------------------------
    class A(object):
        def __init__(self, a, b=2):
            self.a, self.b = a, b

        def get_values(self):
            return self.a, self.b

    assert A(1).get_values() == (1, 2)
    B = override_default_kwargs(A, {"b": 4})
    assert B(1).get_values() == (1, 4)
    # The original class keeps its own default.
    assert A(1).get_values() == (1, 2)
    assert B(1, 3).get_values() == (1, 3)

    # Overriding a keyword that __init__ does not accept is rejected.
    with pytest.raises(ValueError):
        override_default_kwargs(A, {"c": 4})
예제 #2
0
def test_output_schape():
    """Overriding SeqIntervalDl defaults changes the shape reported by
    get_output_schema() on the derived class, while the class it was derived
    from keeps reporting its original schema."""
    # NOTE(review): copy.deepcopy returns classes unchanged (the copy module
    # treats them as atomic), so assuming SeqIntervalDl is a class, `base` is
    # SeqIntervalDl itself. The "base still intact" asserts below are what
    # actually guard against an override mutating it.
    base = deepcopy(SeqIntervalDl)

    def input_shape(dl):
        return dl.get_output_schema().inputs.shape

    def derive(overrides, expected_shape):
        """Apply `overrides` to `base` and check the derived input shape."""
        derived = override_default_kwargs(base, overrides)
        assert input_shape(derived) == expected_shape
        return derived

    assert input_shape(base) == (None, 4)
    # A fixed length replaces the unspecified (None) leading dimension.
    derive({"auto_resize_len": 100}, (100, 4))
    assert input_shape(base) == (None, 4)

    # dummy_axis / alphabet_axis choose where the extra and alphabet axes go.
    derive({"auto_resize_len": 100, "dummy_axis": 1, "alphabet_axis": 2},
           (100, 1, 4))
    derive({"auto_resize_len": 100, "dummy_axis": 2},
           (100, 4, 1))
    assert input_shape(base) == (None, 4)

    # A five-letter alphabet yields an alphabet dimension of size 5.
    derive({"auto_resize_len": 100, "alphabet": "ACGTD"}, (100, 5))

    derive({"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 0},
           (4, 160, 1))
    last = derive({"auto_resize_len": 160, "dummy_axis": 2, "alphabet_axis": 1},
                  (160, 4, 1))
    assert last.get_output_schema().targets.shape == (None, )

    # With ignore_targets the schema reports no targets at all.
    no_targets = override_default_kwargs(base, {"ignore_targets": True})
    assert no_targets.get_output_schema().targets is None

    # The base class still reports its original inputs and targets.
    assert input_shape(base) == (None, 4)
    assert base.get_output_schema().targets.shape == (None, )
예제 #3
0
파일: specs.py 프로젝트: dlhuang/kipoi
    def get(self):
        """Get the dataloader referenced by ``self.defined_as``.

        The object is loaded with ``load_obj`` and must inherit from
        ``kipoi.data.BaseDataLoader``. If ``self.default_args`` is non-empty,
        the object is passed through ``override_default_kwargs`` so those
        values become its new keyword defaults. In addition, every overridden
        argument whose ``example`` is not an ``UNSPECIFIED`` instance gets its
        example replaced by the new default value.

        Returns:
            The loaded dataloader, with defaults overridden when
            ``self.default_args`` is non-empty.

        Raises:
            ValueError: if the loaded object does not inherit from
                ``kipoi.data.BaseDataLoader``.
        """
        from kipoi.data import BaseDataLoader
        obj = load_obj(self.defined_as)

        # check that it inherits from BaseDataLoader
        if not inherits_from(obj, BaseDataLoader):
            raise ValueError(
                "Dataloader: {} doesn't inherit from kipoi.data.BaseDataLoader".format(self.defined_as))

        # Only touch the dataloader when there is something to override. The
        # example update lives under the same guard, so a falsy default_args
        # (e.g. None) is not iterated.
        if self.default_args:
            obj = override_default_kwargs(obj, self.default_args)

            # override also the values in the example in case
            # they were previously specified
            for k, v in six.iteritems(self.default_args):
                # NOTE(review): if obj.args is shared with the originally
                # loaded class (e.g. inherited by a derived class), this
                # assignment also changes the original's example — confirm.
                if not isinstance(obj.args[k].example, UNSPECIFIED):
                    obj.args[k].example = v

        return obj