def test_concurrent():
    """Check that HybridConcurrent and Concurrent concatenate their children.

    Each container holds the same three Dense branches (128, 64 and 32 units,
    all with ``in_units=10``) joined along axis 1. A (32, 10) input should
    therefore yield a (32, 128 + 64 + 32) == (32, 224) output from both. The
    hybrid container is additionally called on a symbol to check that its
    argument list is data plus one weight and one bias per Dense layer.
    """
    def _add_branches(container):
        # Both containers get identical branches; only the container type
        # differs. Layers are created and added in the same order as before,
        # so any auto-generated block names stay the same.
        container.add(nn.Dense(128, activation='tanh', in_units=10))
        container.add(nn.Dense(64, activation='tanh', in_units=10))
        container.add(nn.Dense(32, in_units=10))
        return container

    hybrid = _add_branches(HybridConcurrent(axis=1))
    plain = _add_branches(Concurrent(axis=1))

    # Symbolic call: 'data' + 3 weights + 3 biases.
    data_sym = mx.sym.var('data')
    sym_out = hybrid(data_sym)
    assert len(sym_out.list_arguments()) == 7

    # Imperative call on NDArrays; each container gets its own initializer.
    hybrid.initialize(mx.init.Xavier(magnitude=2.24))
    plain.initialize(mx.init.Xavier(magnitude=2.24))
    outputs = (
        hybrid(mx.nd.zeros((32, 10))),
        plain(mx.nd.zeros((32, 10))),
    )
    for out in outputs:
        assert out.shape == (32, 224)
    # Block until the computation finishes so that asynchronous execution
    # errors are raised inside this test.
    for out in outputs:
        out.wait_to_read()