def test_attrscope():
    """Check that scopes installed via the ``AttrScope.current`` setter are per thread.

    1. A scope set in a spawned thread (``x="hello"``) is recorded alongside
       the main thread's ``y``/``z`` scope.
    2. A spawned thread ``g`` enters ``mx.AttrScope(x="hello")`` and blocks.
       The main thread then replaces its own current scope with
       ``AttrScope(x="hi")``. Only after that is ``g`` released, and ``g`` must
       still see ``"hello"`` in its current scope.

    The main thread's original ``AttrScope.current`` is restored on exit, even
    when an assertion fails.
    """
    # Seconds to wait on a handshake before failing instead of hanging the run.
    timeout = 30
    previous = AttrScope.current
    try:
        attrscope_list = []
        AttrScope.current = AttrScope(y="hi", z="hey")
        attrscope_list.append(AttrScope.current)

        def f():
            AttrScope.current = AttrScope(x="hello")
            attrscope_list.append(AttrScope.current)

        thread = threading.Thread(target=f)
        thread.start()
        thread.join()
        assert len(attrscope_list[0]._attr) == 2
        assert attrscope_list[1]._attr["x"] == "hello"

        entered = threading.Event()  # set by g once it is inside its scope
        release = threading.Event()  # set by main after replacing its own scope
        status = [False]

        def g():
            with mx.AttrScope(x="hello"):
                entered.set()
                release.wait()
                if "hello" in AttrScope.current._attr.values():
                    status[0] = True

        thread = threading.Thread(target=g)
        thread.start()
        try:
            # Make sure g is inside its scope before the main thread changes its own.
            assert entered.wait(timeout), "Spawned thread never entered its AttrScope"
            AttrScope.current = AttrScope(x="hi")
        finally:
            # Always release g so the non-daemon thread can finish, even on failure.
            release.set()
            thread.join(timeout)
        assert not thread.is_alive(), "Spawned thread did not finish"
        assert status[0], "Spawned thread didn't set the correct attr key values"
    finally:
        AttrScope.current = previous
def test_attrscope():
    """Check that ``AttrScope`` context managers are isolated per thread.

    1. While the main thread is inside ``AttrScope(y="hi", z="hey")``, a
       spawned thread enters ``AttrScope(x="hello")``. Both scope objects are
       recorded and checked.
    2. A spawned thread ``g`` enters ``mx.AttrScope(x="hello")`` and signals.
       The main thread then enters ``AttrScope(x="hi")``, releases ``g``, and
       stays in that scope until ``g`` finishes. ``g`` must still find
       ``"hello"`` in ``mx.attribute.current()``.

    Handshakes use a timeout, so a dead or stuck thread fails the test
    instead of hanging it.
    """
    # Seconds to wait on a handshake before failing instead of hanging the run.
    timeout = 30

    attrscope_list = []
    with AttrScope(y="hi", z="hey") as attrscope:
        attrscope_list.append(attrscope)

        def f():
            with AttrScope(x="hello") as attrscope:
                attrscope_list.append(attrscope)

        thread = threading.Thread(target=f)
        thread.start()
        thread.join()
        assert len(attrscope_list[0]._attr) == 2
        assert attrscope_list[1]._attr["x"] == "hello"

    entered = threading.Event()  # set by g once it is inside its scope
    release = threading.Event()  # set by main once it is inside its own scope
    status = [False]

    def g():
        with mx.AttrScope(x="hello"):
            entered.set()
            release.wait()
            if "hello" in mx.attribute.current()._attr.values():
                status[0] = True

    thread = threading.Thread(target=g)
    thread.start()
    try:
        assert entered.wait(timeout), "Spawned thread never entered its AttrScope"
        with AttrScope(x="hi"):
            release.set()
            # Keep the main thread's scope active while g inspects its own.
            thread.join(timeout)
    finally:
        # Always release g so the non-daemon thread can finish, even on failure.
        release.set()
        thread.join(timeout)
    assert not thread.is_alive(), "Spawned thread did not finish"
    assert status[0], "Spawned thread didn't set the correct attr key values"
def test_scope():
    """Check how ``npx.cond`` subgraph names are counted under an AttrScope.

    Two distinct hybridized blocks each run a ``cond`` named ``"my_cond"``
    inside ``AttrScope(__subgraph_name__="my_cond")``. Afterwards,
    ``AttrScope._subgraph_names`` must hold exactly the pred/then/else
    subgraph names, each counted twice.

    ``AttrScope._subgraph_names`` is swapped for a fresh counter during the
    test and restored afterwards, so this test does not leak counts into
    others.
    """
    def cond_forward(data):
        # Shared forward body for both test blocks.
        (new_data, ) = mx.npx.cond(
            pred=lambda data: data > 0.5,
            then_func=lambda data: data * 2,
            else_func=lambda data: data * 3,
            inputs=data,
            name="my_cond",
        )
        return new_data

    # Two distinct block classes that reuse the same cond name.
    class TestBlock1(gluon.HybridBlock):
        def forward(self, data):
            return cond_forward(data)

    class TestBlock2(gluon.HybridBlock):
        def forward(self, data):
            return cond_forward(data)

    saved_subgraph_names = AttrScope._subgraph_names
    AttrScope._subgraph_names = defaultdict(int)
    try:
        data = mx.np.random.normal(loc=0, scale=1, size=(1, ))
        with AttrScope(__subgraph_name__="my_cond"):
            for block_cls in (TestBlock1, TestBlock2):
                block = block_cls()
                block.initialize(ctx=default_context())
                block.hybridize()
                _ = block(data)
        assert len(AttrScope._subgraph_names) == 3
        assert AttrScope._subgraph_names['my_cond$my_cond_else'] == 2
        assert AttrScope._subgraph_names['my_cond$my_cond_pred'] == 2
        assert AttrScope._subgraph_names['my_cond$my_cond_then'] == 2
    finally:
        AttrScope._subgraph_names = saved_subgraph_names
def f():
    """Install ``AttrScope(x="hello")`` via the ``AttrScope.current`` setter and record it.

    Nothing in this function restores the previously current scope. The
    object read back from ``AttrScope.current`` is appended to
    ``attrscope_list``. That list is not defined here and must come from an
    enclosing or module scope.
    """
    new_scope = AttrScope(x="hello")
    AttrScope.current = new_scope
    current_scope = AttrScope.current
    attrscope_list.append(current_scope)
def f():
    """Enter ``AttrScope(x="hello")`` and record the object its ``with`` block yields.

    The scope is exited when the ``with`` block ends. ``attrscope_list`` is
    not defined here and must come from an enclosing or module scope.
    """
    scope_manager = AttrScope(x="hello")
    with scope_manager as entered_scope:
        attrscope_list.append(entered_scope)