import numpy as np
import pytest
from numpy.testing import assert_allclose

import mxnet as mx
# Layers under test; assumed to be importable from gluonnlp.layers.
from gluonnlp.layers import (AdaptiveEmbedding,
                             ProjectedAdaptiveLogSoftmaxWithLoss)

# Use NumPy semantics so the Gluon blocks work with mx.np arrays.
mx.npx.set_np()


# Illustrative parameter combinations (assumed values): cutoffs may be a single
# int or an increasing list below vocab_size; div_val != 1.0 shrinks the
# embedding size of later clusters.
@pytest.mark.parametrize('vocab_size,cutoffs,embed_size,units,div_val',
                         [(500, 50, 16, 32, 1.0),
                          (500, [50, 300], 16, 32, 1.0),
                          (500, [50, 300], 16, 32, 2.0)])
def test_adaptive_embedding(vocab_size, cutoffs, embed_size, units, div_val):
    embed = AdaptiveEmbedding(vocab_size=vocab_size,
                              embed_size=embed_size,
                              units=units,
                              cutoffs=cutoffs,
                              div_val=div_val)
    embed.initialize()
    embed.hybridize()
    # Test for parameter number
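    # With div_val != 1.0, cluster i covers tokens [lhs, rhs) and owns an
    # embedding matrix of shape (rhs - lhs, embed_size / div_val**i) plus a
    # projection of shape (embed_size / div_val**i, units); with div_val == 1.0
    # there is a single (vocab_size, embed_size) embedding and one projection.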
    estimated_param_num = 0
    if isinstance(cutoffs, int):
        cutoffs = [cutoffs]
    if div_val != 1.0:
        for i, (lhs, rhs) in enumerate(
                zip([0] + cutoffs, cutoffs + [vocab_size])):
            estimated_param_num += (rhs - lhs) * int(embed_size / div_val**i)
            estimated_param_num += int(embed_size / div_val**i) * units
    else:
        estimated_param_num = vocab_size * embed_size + embed_size * units
    total_param_num = sum(
        np.prod(p.shape) for p in embed.collect_params().values())
    assert total_param_num == estimated_param_num
    # Test for forward
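    # A batch of 20 random token ids should come back as 20 projected vectors
    # of size `units`.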
    out = embed(mx.np.random.randint(0, vocab_size, 20))
    mx.npx.waitall()
    assert out.shape == (20, units)


# Illustrative parameter combinations (assumed values); in_units matches the
# feature size of the hidden states constructed in the test body.
@pytest.mark.parametrize('vocab_size,cutoffs,embed_size,in_units,div_val',
                         [(500, [50, 300], 16, 32, 1.0),
                          (500, [50, 300], 16, 32, 2.0)])
def test_projected_adaptive_softmax(vocab_size, cutoffs, embed_size, in_units,
                                    div_val):
    layer = ProjectedAdaptiveLogSoftmaxWithLoss(vocab_size=vocab_size,
                                                cutoffs=cutoffs,
                                                embed_size=embed_size,
                                                in_units=in_units,
                                                div_val=div_val)
    layer.initialize()
    layer.hybridize()
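    # The layer maps hidden states of shape (..., in_units) and integer targets
    # to the per-position negative log-likelihood, so the feature axis is
    # dropped from the output.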
    hidden = mx.np.random.normal(0, 1, (4, 4, 4, in_units))
    target = mx.np.random.randint(0, vocab_size, (4, 4, 4))
    out = layer(hidden, target)
    mx.npx.waitall()
    assert out.shape == (4, 4, 4)

    # Test for weight sharing
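    # `collect_params` filters by regex, so the three layers below share only
    # the intermediate projection weights, only the embedding weights, or both
    # with the AdaptiveEmbedding instance.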
    embed_layer = AdaptiveEmbedding(vocab_size=vocab_size,
                                    cutoffs=cutoffs,
                                    units=in_units,
                                    embed_size=embed_size,
                                    div_val=div_val)
    layer_with_shared_proj = \
        ProjectedAdaptiveLogSoftmaxWithLoss(vocab_size=vocab_size,
                                            cutoffs=cutoffs,
                                            embed_size=embed_size,
                                            in_units=in_units,
                                            div_val=div_val)
    layer_with_shared_proj.share_parameters(
        embed_layer.collect_params('inter_proj'))
    layer_with_shared_embed = \
        ProjectedAdaptiveLogSoftmaxWithLoss(vocab_size=vocab_size,
                                            cutoffs=cutoffs,
                                            embed_size=embed_size,
                                            in_units=in_units,
                                            div_val=div_val)
    layer_with_shared_embed.share_parameters(
        embed_layer.collect_params('embed'))
    layer_with_shared_proj_embed = \
        ProjectedAdaptiveLogSoftmaxWithLoss(vocab_size=vocab_size,
                                            cutoffs=cutoffs,
                                            embed_size=embed_size,
                                            in_units=in_units,
                                            div_val=div_val)
    layer_with_shared_proj_embed.share_parameters(
        embed_layer.collect_params('(embed|inter_proj)'))
    embed_layer.initialize()
    embed_layer.hybridize()
    layer_with_shared_proj.initialize()
    layer_with_shared_proj.hybridize()
    layer_with_shared_embed.initialize()
    layer_with_shared_embed.hybridize()
    layer_with_shared_proj_embed.initialize()
    layer_with_shared_proj_embed.hybridize()

    hidden = mx.np.random.normal(0, 1, (4, 4, 4, in_units))
    target = mx.np.random.randint(0, vocab_size, (4, 4, 4))
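    # Run a forward/backward pass through the embedding so that its parameters
    # carry non-trivial gradients to compare against the shared copies below.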
    with mx.autograd.record():
        loss = ((hidden - embed_layer(target))**2).sum()
        loss.backward()
    assert embed_layer(target).asnumpy().shape == hidden.shape

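    # Record the reference weights and gradients of every embedding and
    # projection matrix, keyed by the cluster index parsed from the name.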
    embed_weights = {}
    embed_grads = {}
    proj_weights = {}
    proj_grads = {}
    for k, v in embed_layer.collect_params().items():
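        # Parameter names end in e.g. `_embed0_weight` / `_inter_proj0_weight`;
        # the digit right before '_weight' is the cluster index.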
        if '_embed' in k:
            arr_id = int(k[-len('_weight') - 1])
            embed_weights[arr_id] = v.data()[0].asnumpy()
            embed_grads[arr_id] = v.grad()[0].asnumpy()
        elif '_inter_proj' in k:
            arr_id = int(k[-len('_weight') - 1])
            proj_weights[arr_id] = v.data()[0].asnumpy()
            proj_grads[arr_id] = v.grad()[0].asnumpy()

    # Check shared proj
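    # Only the projection matrices are tied: the embedding weights of the two
    # blocks were initialized independently and should therefore differ.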
    for k, v in layer_with_shared_proj.collect_params().items():
        if '_embed' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            with pytest.raises(AssertionError):
                assert_allclose(v.data()[0].asnumpy(), embed_weights[arr_id])
        elif '_inter_proj' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            assert_allclose(v.data()[0].asnumpy(), proj_weights[arr_id])
            assert_allclose(v.grad()[0].asnumpy(), proj_grads[arr_id])

    # Check shared embed
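    # Only the embedding matrices are tied: the projection weights should
    # differ, while the shared embeddings expose identical weights and grads.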
    for k, v in layer_with_shared_embed.collect_params().items():
        if '_embed' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            assert_allclose(v.data()[0].asnumpy(), embed_weights[arr_id])
            assert_allclose(v.grad()[0].asnumpy(), embed_grads[arr_id])
        elif '_inter_proj' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            with pytest.raises(AssertionError):
                assert_allclose(v.data()[0].asnumpy(), proj_weights[arr_id])

    # Check shared proj + shared embed
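    # Both the embedding and the projection matrices are tied, so all weights
    # and gradients must match the references.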
    for k, v in layer_with_shared_proj_embed.collect_params().items():
        if '_embed' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            assert_allclose(v.data()[0].asnumpy(), embed_weights[arr_id])
            assert_allclose(v.grad()[0].asnumpy(), embed_grads[arr_id])
        elif '_inter_proj' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            assert_allclose(v.data()[0].asnumpy(), proj_weights[arr_id])
            assert_allclose(v.grad()[0].asnumpy(), proj_grads[arr_id])