Example #1
def hybrid_forward(self, F, data, valid_length):
    # Build the self-attention mask for `data`, masking out positions beyond
    # each sample's `valid_length`. `F` is mx.nd/mx.sym under the old hybrid
    # API; dtype, layout, and attention type come from the block's config.
    return gen_self_attn_mask(F,
                              data,
                              valid_length,
                              dtype=self._dtype,
                              layout=self._layout,
                              attn_type=self._attn_type)
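
For reference, gen_self_attn_mask can also be called eagerly. A minimal sketch, assuming
the newer GluonNLP numpy API used in Example #2 (no F argument) and the
gluonnlp.attention_cell import path:

import mxnet as mx
from gluonnlp.attention_cell import gen_self_attn_mask

data = mx.np.random.normal(0, 1, (2, 5, 8))   # layout 'NT': (batch, seq_length, units)
valid_length = mx.np.array([3, 5])            # valid tokens per sample
mask = gen_self_attn_mask(data, valid_length, attn_type='causal')
print(mask.shape)                             # expected (2, 5, 5): (batch, query, key)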
Example #2
# Imports assumed by this test; module paths follow GluonNLP's numpy-based branch.
# pre_norm, num_enc_layers, and num_dec_layers are supplied by pytest parametrization.
import mxnet as mx
from numpy.testing import assert_allclose
from gluonnlp.models.transformer import TransformerEncoder, TransformerDecoder
from gluonnlp.attention_cell import gen_self_attn_mask, gen_mem_attn_mask


def test_transformer_encoder_decoder(pre_norm, num_enc_layers, num_dec_layers):
    batch_size = 8
    src_seq_length = 20
    tgt_seq_length = 15
    units = 32
    enc = TransformerEncoder(units=units, hidden_size=64, num_layers=num_enc_layers, num_heads=4,
                             dropout=0.0, pre_norm=pre_norm)
    dec = TransformerDecoder(units=units, hidden_size=64, num_layers=num_dec_layers, num_heads=4,
                             dropout=0.0, pre_norm=pre_norm)
    enc.hybridize()
    # dec.hybridize() is disabled: attention_cell is called with two different
    # signatures in this test, which hybridization does not support
    # dec.hybridize()
    enc.initialize()
    dec.initialize()
    src_data = mx.np.random.normal(0, 1, (batch_size, src_seq_length, units))
    src_valid_length = mx.np.random.randint(1, src_seq_length, (batch_size,))
    dst_data = mx.np.random.normal(0, 1, (batch_size, tgt_seq_length, units))
    dst_valid_length = mx.np.random.randint(5, tgt_seq_length, (batch_size,))
    encoded_mem = enc(src_data, src_valid_length)
    full_decode_out = dec(dst_data, dst_valid_length, encoded_mem, src_valid_length)

    # Test for the TN layout
    enc_tn = TransformerEncoder(units=units, hidden_size=64, num_layers=num_enc_layers, num_heads=4,
                                dropout=0.0, pre_norm=pre_norm, layout='TN')
    enc_tn.share_parameters(enc.collect_params())
    dec_tn = TransformerDecoder(units=units, hidden_size=64, num_layers=num_dec_layers, num_heads=4,
                                dropout=0.0, pre_norm=pre_norm, layout='TN')
    dec_tn.share_parameters(dec.collect_params())
    enc_tn.hybridize()
    dec_tn.hybridize()
    encoded_mem_tn = enc_tn(mx.np.swapaxes(src_data, 0, 1), src_valid_length)
    full_decode_out_tn = dec_tn(mx.np.swapaxes(dst_data, 0, 1), dst_valid_length,
                                encoded_mem_tn, src_valid_length)
    assert_allclose(encoded_mem_tn.asnumpy(),
                    mx.np.swapaxes(encoded_mem, 0, 1).asnumpy(), 1E-5, 1E-5)
    assert_allclose(full_decode_out_tn.asnumpy(),
                    mx.np.swapaxes(full_decode_out, 0, 1).asnumpy(), 1E-5, 1E-5)

    # Test consistency by shifting the data and the valid_length
    for i in range(1, dst_valid_length.asnumpy().min()):
        for partial_decode_out in [dec(dst_data[:, :(-i), :],
                                       dst_valid_length - i, encoded_mem, src_valid_length),
                                   dec(dst_data, dst_valid_length - i,
                                       encoded_mem, src_valid_length)]:
            for b in range(batch_size):
                vl = dst_valid_length.asnumpy()[b] - i
                assert_allclose(partial_decode_out.asnumpy()[b, :vl, :],
                                full_decode_out.asnumpy()[b, :vl, :], 1E-5, 1E-5)
    # Test the decoder layer
    self_causal_mask = gen_self_attn_mask(dst_data, dst_valid_length, attn_type='causal')
    mem_attn_mask = gen_mem_attn_mask(encoded_mem, src_valid_length, dst_data, dst_valid_length)
    enc_mem_attn_mask = gen_mem_attn_mask(encoded_mem, src_valid_length, dst_data[:, 0:1, :],
                                          None)
    h_out = dec.layers[0](dst_data, encoded_mem, self_causal_mask, mem_attn_mask)
    states = dec.layers[0].init_states(batch_size, h_out.ctx, h_out.dtype)
    h_out_from_incremental = []
    for i in range(tgt_seq_length):
        ele_h_out, states = dec.layers[0].incremental_decode(dst_data[:, i, :], states,
                                                             encoded_mem, src_valid_length,
                                                             enc_mem_attn_mask)
        h_out_from_incremental.append(ele_h_out)
    h_out_from_incremental = mx.np.stack(h_out_from_incremental, axis=1)

    for i in range(batch_size):
        val_length = int(dst_valid_length[i].asnumpy())
        assert_allclose(h_out_from_incremental[i, :val_length, :].asnumpy(),
                        h_out[i, :val_length, :].asnumpy(), 1E-5, 1E-5)
    # Test for the full decoder
    states = dec.init_states(batch_size, src_data.ctx, src_data.dtype)
    final_out_from_incremental = []
    for i in range(tgt_seq_length):
        ele_final_out, states = dec.incremental_decode(dst_data[:, i, :],
                                                       states, encoded_mem, src_valid_length)
        final_out_from_incremental.append(ele_final_out)
    final_out_from_incremental = mx.np.stack(final_out_from_incremental, axis=1)
    for i in range(batch_size):
        val_length = int(dst_valid_length[i].asnumpy())
        assert_allclose(final_out_from_incremental[i, :val_length, :].asnumpy(),
                        full_decode_out[i, :val_length, :].asnumpy(), 1E-5, 1E-5)
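
The memory attention mask used above can be inspected standalone as well. A minimal
sketch, again assuming the gluonnlp.attention_cell import path:

import mxnet as mx
from gluonnlp.attention_cell import gen_mem_attn_mask

mem = mx.np.random.normal(0, 1, (2, 6, 8))    # encoder memory: (batch, mem_length, units)
mem_valid = mx.np.array([4, 6])               # valid source tokens per sample
data = mx.np.random.normal(0, 1, (2, 3, 8))   # decoder input: (batch, query_length, units)
data_valid = mx.np.array([2, 3])              # valid target tokens per sample
mask = gen_mem_attn_mask(mem, mem_valid, data, data_valid)
print(mask.shape)                             # expected (2, 3, 6): queries attend to valid memory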