コード例 #1
0
def test_attrscope():
    """Check that ``AttrScope.current`` assigned in one thread is not seen by another.

    Part 1: the main thread and a spawned thread each assign their own scope
    to ``AttrScope.current``. Each recorded scope must keep its own attributes.

    Part 2: a spawned thread holds ``mx.AttrScope(x="hello")`` open while the
    main thread assigns a different scope. The spawned thread must still find
    "hello" among its current scope's attribute values.

    ``AttrScope.current`` is reset to a fresh ``AttrScope()`` on exit, even
    when an assertion fails, so later tests do not inherit this test's scope.
    """
    attrscope_list = []
    try:
        AttrScope.current = AttrScope(y="hi", z="hey")
        attrscope_list.append(AttrScope.current)

        def f():
            AttrScope.current = AttrScope(x="hello")
            attrscope_list.append(AttrScope.current)

        thread = threading.Thread(target=f)
        thread.start()
        thread.join()
        assert len(attrscope_list[0]._attr) == 2
        assert attrscope_list[1]._attr["x"] == "hello"

        event = threading.Event()
        entered = threading.Event()
        status = [False]

        def g():
            try:
                with mx.AttrScope(x="hello"):
                    # Signal that this thread's scope is active *before* the
                    # main thread swaps its own, so the interleaving under
                    # test actually happens.
                    entered.set()
                    event.wait()
                    if "hello" in AttrScope.current._attr.values():
                        status[0] = True
            finally:
                # Never leave the main thread blocked on `entered` if the
                # scope could not be entered.
                entered.set()

        thread = threading.Thread(target=g)
        thread.start()
        entered.wait()
        try:
            AttrScope.current = AttrScope(x="hi")
        finally:
            # Always release g; otherwise a failure here would leave a
            # non-daemon thread blocked forever and hang the test run.
            event.set()
            thread.join()
        assert status[0], "Spawned thread didn't set the correct attr key values"
    finally:
        AttrScope.current = AttrScope()
コード例 #2
0
ファイル: test_thread_local.py プロジェクト: MarkMa1990/mxnet
def test_attrscope():
    """Verify that attribute scopes entered via ``with`` are isolated per thread.

    First, a scope entered in a worker thread must not disturb the scope the
    main thread has entered; each recorded scope keeps its own attributes.
    Then a worker holds ``mx.AttrScope(x="hello")`` open while the main thread
    enters ``AttrScope(x="hi")``; the worker must still read "hello" from
    ``mx.attribute.current()``.
    """
    recorded_scopes = []

    def enter_in_worker():
        with AttrScope(x="hello") as worker_scope:
            recorded_scopes.append(worker_scope)

    with AttrScope(y="hi", z="hey") as outer_scope:
        recorded_scopes.append(outer_scope)
        worker = threading.Thread(target=enter_in_worker)
        worker.start()
        worker.join()
        assert len(recorded_scopes[0]._attr) == 2
        assert recorded_scopes[1]._attr["x"] == "hello"

    released = threading.Event()  # main -> worker: go check your scope
    ready = threading.Event()     # worker -> main: my scope is entered
    outcome = [False]

    def hold_scope():
        with mx.AttrScope(x="hello"):
            ready.set()
            released.wait()
            active_values = mx.attribute.current()._attr.values()
            if "hello" in active_values:
                outcome[0] = True

    worker = threading.Thread(target=hold_scope)
    worker.start()
    ready.wait()
    with AttrScope(x="hi"):
        released.set()
        worker.join()
    for evt in (released, ready):
        evt.clear()
    assert outcome[0], "Spawned thread didn't set the correct attr key values"
コード例 #3
0
def test_scope():
    """Check subgraph-name counting for ``mx.npx.cond`` under a named AttrScope.

    Two hybridized blocks each run a cond named "my_cond" inside
    ``AttrScope(__subgraph_name__="my_cond")``. Afterwards each of the three
    cond subgraphs (pred/then/else) must have been counted exactly twice.

    ``AttrScope._subgraph_names`` is class-level state shared with other
    tests, so it is swapped for a fresh counter only for the duration of this
    test and restored afterwards, even on failure.
    """
    class _CondBlock(gluon.HybridBlock):
        # Shared body of both test blocks (previously duplicated verbatim):
        # out = data * 2 where data > 0.5, else data * 3.
        def forward(self, data):
            (new_data, ) = mx.npx.cond(
                pred=lambda data: data > 0.5,
                then_func=lambda data: data * 2,
                else_func=lambda data: data * 3,
                inputs=data,
                name="my_cond",
            )
            return new_data

    # Two distinct classes are kept so the test still exercises two
    # separately defined blocks that reuse the same cond name.
    class TestBlock1(_CondBlock):
        pass

    class TestBlock2(_CondBlock):
        pass

    saved_subgraph_names = AttrScope._subgraph_names
    AttrScope._subgraph_names = defaultdict(int)
    try:
        data = mx.np.random.normal(loc=0, scale=1, size=(1, ))
        with AttrScope(__subgraph_name__="my_cond"):
            blocks = []  # keep every block alive until the checks run
            for block_cls in (TestBlock1, TestBlock2):
                block = block_cls()
                block.initialize(ctx=default_context())
                block.hybridize()
                _ = block(data)
                blocks.append(block)
            # Check the size first: indexing a defaultdict inserts missing keys.
            assert len(AttrScope._subgraph_names) == 3
            for key in ('my_cond$my_cond_else',
                        'my_cond$my_cond_pred',
                        'my_cond$my_cond_then'):
                assert AttrScope._subgraph_names[key] == 2
    finally:
        AttrScope._subgraph_names = saved_subgraph_names
コード例 #4
0
 def f():
     """Assign AttrScope(x="hello") to AttrScope.current and record the result.

     The recorded object is whatever ``AttrScope.current`` returns after the
     assignment; ``attrscope_list`` comes from the enclosing scope.
     """
     AttrScope.current = AttrScope(x="hello")
     current_scope = AttrScope.current
     attrscope_list.append(current_scope)
コード例 #5
0
ファイル: test_thread_local.py プロジェクト: MarkMa1990/mxnet
 def f():
     """Enter AttrScope(x="hello") and record the entered scope object.

     ``attrscope_list`` comes from the enclosing scope; the append happens
     while the scope is still active.
     """
     scope_manager = AttrScope(x="hello")
     with scope_manager as active_scope:
         attrscope_list.append(active_scope)