def test_np_array_scope():
    """Check NumPy-array scope state across threads.

    Part 1: the main thread and a spawned thread each assign
    ``_NumpyArrayScope._current``. Both scope objects are recorded, and the
    test checks their ``_is_np_array`` flags are False and True respectively.

    Part 2: a spawned thread enters ``mx.np_array(False)`` and blocks on an
    event. While it is blocked, the main thread assigns a
    ``_NumpyArrayScope(True)`` to ``_NumpyArrayScope._current``. Once
    released, the spawned thread must still see ``mx.is_np_array()`` as
    false inside its own ``with`` block.

    ``_NumpyArrayScope._current`` is a class attribute that outlives this
    test. Its original value is saved up front and restored on exit, so
    later tests do not inherit the scope objects installed here.
    """
    saved_current = _NumpyArrayScope._current
    try:
        np_array_scope_list = []
        _NumpyArrayScope._current = _NumpyArrayScope(False)
        np_array_scope_list.append(_NumpyArrayScope._current)

        def f():
            _NumpyArrayScope._current = _NumpyArrayScope(True)
            np_array_scope_list.append(_NumpyArrayScope._current)

        thread = threading.Thread(target=f)
        thread.start()
        thread.join()
        assert len(np_array_scope_list) == 2
        assert not np_array_scope_list[0]._is_np_array
        assert np_array_scope_list[1]._is_np_array

        event = threading.Event()
        # Mutable cell so the worker thread can report its observation.
        status = [False]

        def g():
            with mx.np_array(False):
                event.wait()
                if not mx.is_np_array():
                    status[0] = True

        thread = threading.Thread(target=g)
        thread.start()
        try:
            _NumpyArrayScope._current = _NumpyArrayScope(True)
        finally:
            # Always release g. Otherwise an exception above would leave a
            # non-daemon thread blocked in event.wait() forever.
            event.set()
            thread.join()
        assert status[0], "Spawned thread didn't set status correctly"
    finally:
        # Undo the class-level mutation so later tests see the original value.
        _NumpyArrayScope._current = saved_current
def f():
    """Install a new ``_NumpyArrayScope(True)`` as the current scope and record it.

    Assigns a fresh ``_NumpyArrayScope(True)`` to the class attribute
    ``_NumpyArrayScope._current``. It then appends the value read back from
    that attribute to ``np_array_scope_list``.

    NOTE(review): ``np_array_scope_list`` is a free name here, and no
    module-level binding for it is visible in this chunk. Calling this
    top-level ``f`` presumably raises NameError unless such a binding exists
    elsewhere. This function mirrors the nested ``f`` inside
    ``test_np_array_scope``; confirm whether this copy is intentional.
    """
    scope_cls = _NumpyArrayScope
    scope_cls._current = scope_cls(True)
    np_array_scope_list.append(scope_cls._current)