def test_concatenate():
    """Check that HybridConcatenate and Concatenate join Dense outputs on axis 1.

    Each container gets three Dense children (128, 64 and 32 units, all with
    ``in_units=10``), so a ``(32, 10)`` input must produce a ``(32, 224)``
    output (128 + 64 + 32 = 224).
    """
    containers = []
    for container_cls in (nn.HybridConcatenate, nn.Concatenate):
        net = container_cls(axis=1)
        for units, activation in ((128, 'tanh'), (64, 'tanh')):
            net.add(nn.Dense(units, activation=activation, in_units=10))
        # The last branch has no activation.
        net.add(nn.Dense(32, in_units=10))
        containers.append(net)

    # ndarray
    for net in containers:
        net.initialize(mx.init.Xavier(magnitude=2.24))
    outputs = [net(mx.np.zeros((32, 10))) for net in containers]
    for out in outputs:
        assert out.shape == (32, 224)
    # Block until the results are available.
    for out in outputs:
        out.wait_to_read()
 def __init__(self, input_num, dim, **kwargs):
     """Create a HybridConcatenate along ``dim`` holding Identity children.

     ``input_num`` Identity blocks are added to the container, which is
     stored as ``self.concat``. Extra keyword arguments are forwarded to the
     parent initializer.
     """
     super(SingleConcat, self).__init__(**kwargs)
     concat_block = nn.HybridConcatenate(axis=dim)
     self.concat = concat_block
     # Add ``input_num`` Identity children to the container.
     for _ in range(input_num):
         concat_block.add(nn.Identity())