import numpy as np
import pytest
from numpy.testing import assert_allclose

import mxnet as mx
# AdaptiveEmbedding and ProjectedAdaptiveLogSoftmaxWithLoss are assumed to live in
# gluonnlp.layers; adjust the import path if they are located elsewhere in your checkout.
from gluonnlp.layers import AdaptiveEmbedding, ProjectedAdaptiveLogSoftmaxWithLoss

mx.npx.set_np()  # use NumPy-compatible (mx.np) arrays throughout


def test_adaptive_embedding(vocab_size, cutoffs, embed_size, units, div_val):
    embed = AdaptiveEmbedding(vocab_size=vocab_size, embed_size=embed_size, units=units,
                              cutoffs=cutoffs, div_val=div_val)
    embed.initialize()
    embed.hybridize()

    # Test for parameter number: the actual count must match the analytic estimate.
    estimated_param_num = 0
    if isinstance(cutoffs, int):
        cutoffs = [cutoffs]
    if div_val != 1.0:
        # Cluster i holds (rhs - lhs) embeddings of size embed_size / div_val**i,
        # plus a projection from that reduced size up to `units`.
        for i, (lhs, rhs) in enumerate(zip([0] + cutoffs, cutoffs + [vocab_size])):
            estimated_param_num += (rhs - lhs) * int(embed_size / div_val**i)
            estimated_param_num += int(embed_size / div_val**i) * units
    else:
        # A single embedding matrix plus one projection to `units`.
        estimated_param_num = vocab_size * embed_size + embed_size * units
    total_param_num = sum(np.prod(p.shape) for p in embed.collect_params().values())
    assert total_param_num == estimated_param_num

    # Test for forward
    out = embed(mx.np.random.randint(0, vocab_size, 20))
    mx.npx.waitall()
    assert out.shape == (20, units)
def test_projected_adaptive_softmax(vocab_size, cutoffs, embed_size, in_units, div_val):
    layer = ProjectedAdaptiveLogSoftmaxWithLoss(vocab_size=vocab_size, cutoffs=cutoffs,
                                                embed_size=embed_size, in_units=in_units,
                                                div_val=div_val)
    layer.initialize()
    layer.hybridize()
    hidden = mx.np.random.normal(0, 1, (4, 4, 4, 16))
    target = mx.np.random.randint(0, vocab_size, (4, 4, 4))
    out = layer(hidden, target)
    mx.npx.waitall()
    assert out.shape == (4, 4, 4)

    # Test for weight sharing
    embed_layer = AdaptiveEmbedding(vocab_size=vocab_size, cutoffs=cutoffs, units=in_units,
                                    embed_size=embed_size, div_val=div_val)
    layer_with_shared_proj = \
        ProjectedAdaptiveLogSoftmaxWithLoss(vocab_size=vocab_size, cutoffs=cutoffs,
                                            embed_size=embed_size, in_units=in_units,
                                            div_val=div_val)
    layer_with_shared_proj.share_parameters(embed_layer.collect_params('inter_proj'))
    layer_with_shared_embed = \
        ProjectedAdaptiveLogSoftmaxWithLoss(vocab_size=vocab_size, cutoffs=cutoffs,
                                            embed_size=embed_size, in_units=in_units,
                                            div_val=div_val)
    layer_with_shared_embed.share_parameters(embed_layer.collect_params('embed'))
    layer_with_shared_proj_embed = \
        ProjectedAdaptiveLogSoftmaxWithLoss(vocab_size=vocab_size, cutoffs=cutoffs,
                                            embed_size=embed_size, in_units=in_units,
                                            div_val=div_val)
    layer_with_shared_proj_embed.share_parameters(
        embed_layer.collect_params('(embed|inter_proj)'))
    embed_layer.initialize()
    embed_layer.hybridize()
    layer_with_shared_proj.initialize()
    layer_with_shared_proj.hybridize()
    layer_with_shared_embed.initialize()
    layer_with_shared_embed.hybridize()
    layer_with_shared_proj_embed.initialize()
    layer_with_shared_proj_embed.hybridize()

    # Run a backward pass through the embedding so its gradients are populated
    # before the shared-parameter checks below.
    hidden = mx.np.random.normal(0, 1, (4, 4, 4, 16))
    target = mx.np.random.randint(0, vocab_size, (4, 4, 4))
    with mx.autograd.record():
        loss = ((hidden - embed_layer(target)) ** 2).sum()
        loss.backward()
    assert embed_layer(target).asnumpy().shape == hidden.shape

    # Collect the embedding layer's weights and gradients, keyed by the cluster index
    # (the digit immediately preceding '_weight' in the parameter name).
    embed_weights = {}
    embed_grads = {}
    proj_weights = {}
    proj_grads = {}
    for k, v in embed_layer.collect_params().items():
        if '_embed' in k:
            arr_id = int(k[-len('_weight') - 1])
            embed_weights[arr_id] = v.data()[0].asnumpy()
            embed_grads[arr_id] = v.grad()[0].asnumpy()
        elif '_inter_proj' in k:
            arr_id = int(k[-len('_weight') - 1])
            proj_weights[arr_id] = v.data()[0].asnumpy()
            proj_grads[arr_id] = v.grad()[0].asnumpy()

    # Check shared proj: projections match the embedding layer's, embeddings must differ.
    for k, v in layer_with_shared_proj.collect_params().items():
        if '_embed' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            with pytest.raises(AssertionError):
                assert_allclose(v.data()[0].asnumpy(), embed_weights[arr_id])
        elif '_inter_proj' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            assert_allclose(v.data()[0].asnumpy(), proj_weights[arr_id])
            assert_allclose(v.grad()[0].asnumpy(), proj_grads[arr_id])

    # Check shared embed: embeddings match, projections must differ.
    for k, v in layer_with_shared_embed.collect_params().items():
        if '_embed' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            assert_allclose(v.data()[0].asnumpy(), embed_weights[arr_id])
            assert_allclose(v.grad()[0].asnumpy(), embed_grads[arr_id])
        elif '_inter_proj' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            with pytest.raises(AssertionError):
                assert_allclose(v.data()[0].asnumpy(), proj_weights[arr_id])

    # Check shared proj + shared embed: both weight sets match.
    for k, v in layer_with_shared_proj_embed.collect_params().items():
        if '_embed' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            assert_allclose(v.data()[0].asnumpy(), embed_weights[arr_id])
            assert_allclose(v.grad()[0].asnumpy(), embed_grads[arr_id])
        elif '_inter_proj' in k and '_weight' in k:
            arr_id = int(k[-len('_weight') - 1])
            assert_allclose(v.data()[0].asnumpy(), proj_weights[arr_id])
            assert_allclose(v.grad()[0].asnumpy(), proj_grads[arr_id])