Example #1
0
 def test_to(self, major, sub, custom, pytestconfig):
     """Check that a TransitionBase can be moved to the configured device.

     The target device is read from the ``gpu_device`` pytest option. The
     test only requires ``tb.to(...)`` not to raise; its result is not
     inspected.
     """
     tb = TransitionBase(major_attr=[m[0] for m in major],
                         sub_attr=[s[0] for s in sub],
                         custom_attr=[c[0] for c in custom],
                         major_data=[m[1] for m in major],
                         sub_data=[s[1] for s in sub],
                         custom_data=[c[1] for c in custom])
     tb.to(pytestconfig.getoption("gpu_device"))
Example #2
0
 def test_dynamic_set_get(self):
     """Assigning an undeclared key must fail for item and attribute syntax.

     Both ``obj["some_attr"] = ...`` and ``obj.some_attr = ...`` are expected
     to raise ``RuntimeError`` with a message matching
     "You cannot dynamically set".
     """
     major_data = [{"ma1_1": t.zeros([2, 2])}]
     sub_data = [t.zeros([2, 4])]
     transition = TransitionBase(major_attr=["ma1"],
                                 sub_attr=["sa1"],
                                 custom_attr=["ca"],
                                 major_data=major_data,
                                 sub_data=sub_data,
                                 custom_data=[None])
     expected_error = "You cannot dynamically set"

     # "some_attr" is not among the names passed at construction time;
     # item-style assignment of it must be rejected.
     with pytest.raises(RuntimeError, match=expected_error):
         transition["some_attr"] = 1

     # Attribute-style assignment of the same name must be rejected too.
     with pytest.raises(RuntimeError, match=expected_error):
         transition.some_attr = 1
Example #3
0
 def test_init(self, major, sub, custom, exception, match):
     """Construct a TransitionBase and check whether construction raises.

     When ``exception`` is not None, construction must raise that exception
     type with a message matching ``match``. Otherwise construction must
     succeed without raising.
     """
     import contextlib

     # Keep a single construction site. Wrap it in pytest.raises only when
     # an error is expected; nullcontext neither catches nor suppresses
     # anything, so the success case behaves exactly like a bare call.
     if exception is not None:
         expectation = pytest.raises(exception, match=match)
     else:
         expectation = contextlib.nullcontext()

     # The argument lists are built inside the ``with`` (as before), so any
     # error raised while assembling them is also subject to the expectation.
     with expectation:
         TransitionBase(major_attr=[m[0] for m in major],
                        sub_attr=[s[0] for s in sub],
                        custom_attr=[c[0] for c in custom],
                        major_data=[m[1] for m in major],
                        sub_data=[s[1] for s in sub],
                        custom_data=[c[1] for c in custom])
Example #4
0
 def test_len(self, major, sub, custom, length):
     """``len()`` of the constructed TransitionBase must equal ``length``."""
     def names(entries):
         # Element 0 of each parametrized entry is passed as an attribute name.
         return [entry[0] for entry in entries]

     def values(entries):
         # Element 1 of each parametrized entry is passed as the matching data.
         return [entry[1] for entry in entries]

     transition = TransitionBase(major_attr=names(major),
                                 sub_attr=names(sub),
                                 custom_attr=names(custom),
                                 major_data=values(major),
                                 sub_data=values(sub),
                                 custom_data=values(custom))
     assert len(transition) == length
Example #5
0
 def test_set_get(self, major, sub, custom, key, value):
     """Round-trip ``key`` through item and attribute access.

     ``value`` must already be readable under ``key`` right after
     construction. It is then written back with ``obj[key] = value``, and
     both read paths are checked again.
     """
     transition = TransitionBase(major_attr=[entry[0] for entry in major],
                                 sub_attr=[entry[0] for entry in sub],
                                 custom_attr=[entry[0] for entry in custom],
                                 major_data=[entry[1] for entry in major],
                                 sub_data=[entry[1] for entry in sub],
                                 custom_data=[entry[1] for entry in custom])

     def assert_reads_back():
         # Item access and attribute access must both yield ``value``.
         assert t.all(transition[key] == value)
         assert t.all(getattr(transition, key) == value)

     assert_reads_back()
     transition[key] = value
     assert_reads_back()
Example #6
0
 def test_attr(self, major, sub, custom):
     """Check name lists, ``keys()``, ``items()`` and ``has_keys()``.

     The per-category name lists and the combined key order must match the
     construction input. Every ``(name, data)`` from ``items()`` must match
     the input (element-wise when both sides are tensors). ``has_keys`` must
     accept the object's own keys and reject unknown names.
     """
     def firsts(entries):
         # Fresh list of element 0 of each entry (the attribute names).
         return [entry[0] for entry in entries]

     def seconds(entries):
         # Fresh list of element 1 of each entry (the attribute data).
         return [entry[1] for entry in entries]

     transition = TransitionBase(major_attr=firsts(major),
                                 sub_attr=firsts(sub),
                                 custom_attr=firsts(custom),
                                 major_data=seconds(major),
                                 sub_data=seconds(sub),
                                 custom_data=seconds(custom))

     assert transition.major_attr == firsts(major)
     assert transition.sub_attr == firsts(sub)
     assert transition.custom_attr == firsts(custom)
     # keys() lists major names, then sub names, then custom names.
     assert transition.keys() == (firsts(major) + firsts(sub) +
                                  firsts(custom))

     expected = dict(major + sub + custom)
     for name, data in transition.items():
         assert name in expected
         reference = expected[name]
         if t.is_tensor(data) and t.is_tensor(reference):
             # Tensor equality is element-wise, so reduce with t.all.
             assert t.all(reference == data)
         else:
             assert reference == data

     assert transition.has_keys(transition.keys())
     # Random names that are not among the constructed attributes.
     assert not transition.has_keys(["cSHxn3pyd1", "53D0dape5r"])