def test_to(self, major, sub, custom, pytestconfig):
    """Smoke-test moving a TransitionBase to the configured device.

    Builds a TransitionBase from the parametrized ``major``/``sub``/``custom``
    entries, where each entry's first item is the attribute name and its
    second item is the data. It then calls ``tb.to`` with the value of the
    ``gpu_device`` pytest option.

    The test only checks that ``to`` does not raise. It makes no assertion
    about the resulting placement of the data.
    """
    tb = TransitionBase(
        major_attr=[m[0] for m in major],
        sub_attr=[s[0] for s in sub],
        custom_attr=[c[0] for c in custom],
        major_data=[m[1] for m in major],
        sub_data=[s[1] for s in sub],
        custom_data=[c[1] for c in custom],
    )
    tb.to(pytestconfig.getoption("gpu_device"))
def test_dynamic_set_get(self):
    """Assigning a name that was not declared at construction must fail.

    Both subscript assignment and attribute assignment of the undeclared
    name ``some_attr`` are expected to raise ``RuntimeError`` with a message
    matching "You cannot dynamically set".
    """
    major_tensor = t.zeros([2, 2])
    sub_tensor = t.zeros([2, 4])
    tb = TransitionBase(
        major_attr=["ma1"],
        sub_attr=["sa1"],
        custom_attr=["ca"],
        major_data=[{"ma1_1": major_tensor}],
        sub_data=[sub_tensor],
        custom_data=[None],
    )
    refusal = "You cannot dynamically set"

    # Subscript path.
    with pytest.raises(RuntimeError, match=refusal):
        tb["some_attr"] = 1

    # Attribute path; setattr is equivalent to ``tb.some_attr = 1``.
    with pytest.raises(RuntimeError, match=refusal):
        setattr(tb, "some_attr", 1)
def test_init(self, major, sub, custom, exception, match):
    """Check TransitionBase construction succeeds or fails as parametrized.

    If ``exception`` is None, constructing a TransitionBase from the
    ``major``/``sub``/``custom`` entries must succeed. Each entry's first
    item is the attribute name and its second item is the data. Otherwise,
    construction must raise ``exception`` with a message matching the
    ``match`` regex.
    """

    def build():
        # Single construction path shared by both branches, so the
        # success and failure cases can never pass different arguments.
        return TransitionBase(
            major_attr=[m[0] for m in major],
            sub_attr=[s[0] for s in sub],
            custom_attr=[c[0] for c in custom],
            major_data=[m[1] for m in major],
            sub_data=[s[1] for s in sub],
            custom_data=[c[1] for c in custom],
        )

    if exception is None:
        build()
    else:
        with pytest.raises(exception, match=match):
            build()
def test_len(self, major, sub, custom, length):
    """``len()`` of a freshly built TransitionBase must equal ``length``.

    Each entry in ``major``/``sub``/``custom`` supplies the attribute name
    at index 0 and its data at index 1.
    """
    groups = (major, sub, custom)
    major_attr, sub_attr, custom_attr = (
        [entry[0] for entry in group] for group in groups
    )
    major_data, sub_data, custom_data = (
        [entry[1] for entry in group] for group in groups
    )
    tb = TransitionBase(
        major_attr=major_attr,
        sub_attr=sub_attr,
        custom_attr=custom_attr,
        major_data=major_data,
        sub_data=sub_data,
        custom_data=custom_data,
    )
    assert len(tb) == length
def test_set_get(self, major, sub, custom, key, value):
    """Subscript and attribute reads of ``key`` must both yield ``value``.

    The check runs once after construction and again after writing
    ``value`` back through subscript assignment. The values are compared
    element-wise with ``t.all``.
    """
    major_attr = [entry[0] for entry in major]
    sub_attr = [entry[0] for entry in sub]
    custom_attr = [entry[0] for entry in custom]
    major_data = [entry[1] for entry in major]
    sub_data = [entry[1] for entry in sub]
    custom_data = [entry[1] for entry in custom]
    tb = TransitionBase(
        major_attr=major_attr,
        sub_attr=sub_attr,
        custom_attr=custom_attr,
        major_data=major_data,
        sub_data=sub_data,
        custom_data=custom_data,
    )

    def check_both_access_paths():
        assert t.all(tb[key] == value)
        assert t.all(getattr(tb, key) == value)

    check_both_access_paths()
    # Re-assigning the same value must leave both read paths consistent.
    tb[key] = value
    check_both_access_paths()
def test_attr(self, major, sub, custom):
    """The attribute lists, ``keys()``, ``items()`` and ``has_keys()`` must agree with the inputs.

    ``major``/``sub``/``custom`` are sequences of (name, data) pairs.
    Tensor values are compared element-wise with ``t.all``. Other values
    are compared with ``==``.
    """

    def names(entries):
        # A fresh list on every call, so comparisons never alias the
        # lists handed to the constructor.
        return [entry[0] for entry in entries]

    def values(entries):
        return [entry[1] for entry in entries]

    tb = TransitionBase(
        major_attr=names(major),
        sub_attr=names(sub),
        custom_attr=names(custom),
        major_data=values(major),
        sub_data=values(sub),
        custom_data=values(custom),
    )

    assert tb.major_attr == names(major)
    assert tb.sub_attr == names(sub)
    assert tb.custom_attr == names(custom)
    assert tb.keys() == (names(major) + names(sub) + names(custom))

    expected = dict(major + sub + custom)
    for name, actual in tb.items():
        assert name in expected
        want = expected[name]
        if t.is_tensor(actual) and t.is_tensor(want):
            assert t.all(want == actual)
            continue
        assert want == actual

    assert tb.has_keys(tb.keys())
    assert not tb.has_keys(["cSHxn3pyd1", "53D0dape5r"])