def hybrid_forward(self, F, data, valid_length):
    return gen_self_attn_mask(F, data, valid_length,
                              dtype=self._dtype,
                              layout=self._layout,
                              attn_type=self._attn_type)
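
# Hedged usage sketch (assumption, not part of the original file): the
# hybrid_forward above presumably belongs to a small HybridBlock wrapper around
# gen_self_attn_mask whose __init__ stores dtype/layout/attn_type. The wrapper
# name `MaskGen` below is hypothetical and only illustrates the call pattern:
#
#     mask_gen = MaskGen(dtype='float32', layout='NT', attn_type='causal')
#     mask_gen.hybridize()
#     data = mx.np.random.normal(0, 1, (4, 16, 32))       # (batch, seq, units) for 'NT'
#     valid_length = mx.np.random.randint(1, 16, (4,))
#     mask = mask_gen(data, valid_length)                  # attention mask for the sequence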
def test_transformer_encoder_decoder(pre_norm, num_enc_layers, num_dec_layers):
    batch_size = 8
    src_seq_length = 20
    tgt_seq_length = 15
    units = 32
    enc = TransformerEncoder(units=units, hidden_size=64, num_layers=num_enc_layers,
                             num_heads=4, dropout=0.0, pre_norm=pre_norm)
    dec = TransformerDecoder(units=units, hidden_size=64, num_layers=num_dec_layers,
                             num_heads=4, dropout=0.0, pre_norm=pre_norm)
    enc.hybridize()
    # disabled due to two different signatures calling attention_cell in this test
    # dec.hybridize()
    enc.initialize()
    dec.initialize()
    src_data = mx.np.random.normal(0, 1, (batch_size, src_seq_length, units))
    src_valid_length = mx.np.random.randint(1, src_seq_length, (batch_size,))
    dst_data = mx.np.random.normal(0, 1, (batch_size, tgt_seq_length, units))
    dst_valid_length = mx.np.random.randint(5, tgt_seq_length, (batch_size,))
    encoded_mem = enc(src_data, src_valid_length)
    full_decode_out = dec(dst_data, dst_valid_length, encoded_mem, src_valid_length)

    # Test for the TN layout
    enc_tn = TransformerEncoder(units=units, hidden_size=64, num_layers=num_enc_layers,
                                num_heads=4, dropout=0.0, pre_norm=pre_norm, layout='TN')
    enc_tn.share_parameters(enc.collect_params())
    dec_tn = TransformerDecoder(units=units, hidden_size=64, num_layers=num_dec_layers,
                                num_heads=4, dropout=0.0, pre_norm=pre_norm, layout='TN')
    dec_tn.share_parameters(dec.collect_params())
    enc_tn.hybridize()
    dec_tn.hybridize()
    encoded_mem_tn = enc_tn(mx.np.swapaxes(src_data, 0, 1), src_valid_length)
    full_decode_out_tn = dec_tn(mx.np.swapaxes(dst_data, 0, 1), dst_valid_length,
                                encoded_mem_tn, src_valid_length)
    assert_allclose(encoded_mem_tn.asnumpy(),
                    mx.np.swapaxes(encoded_mem, 0, 1).asnumpy(), 1E-5, 1E-5)
    assert_allclose(full_decode_out_tn.asnumpy(),
                    mx.np.swapaxes(full_decode_out, 0, 1).asnumpy(), 1E-5, 1E-5)

    # Test the consistency via shifting the data and the valid_length
    for i in range(1, dst_valid_length.asnumpy().min()):
        for partial_decode_out in [dec(dst_data[:, :(-i), :], dst_valid_length - i,
                                       encoded_mem, src_valid_length),
                                   dec(dst_data, dst_valid_length - i,
                                       encoded_mem, src_valid_length)]:
            for b in range(batch_size):
                vl = dst_valid_length.asnumpy()[b] - i
                assert_allclose(partial_decode_out.asnumpy()[b, :vl, :],
                                full_decode_out.asnumpy()[b, :vl, :], 1E-5, 1E-5)

    # Test the decoder layer
    self_causal_mask = gen_self_attn_mask(dst_data, dst_valid_length, attn_type='causal')
    mem_attn_mask = gen_mem_attn_mask(encoded_mem, src_valid_length, dst_data, dst_valid_length)
    enc_mem_attn_mask = gen_mem_attn_mask(encoded_mem, src_valid_length,
                                          dst_data[:, 0:1, :], None)
    print(enc_mem_attn_mask)
    h_out = dec.layers[0](dst_data, encoded_mem, self_causal_mask, mem_attn_mask)
    states = dec.layers[0].init_states(batch_size, h_out.ctx, h_out.dtype)
    h_out_from_incremental = []
    for i in range(tgt_seq_length):
        ele_h_out, states = dec.layers[0].incremental_decode(dst_data[:, i, :], states,
                                                             encoded_mem, src_valid_length,
                                                             enc_mem_attn_mask)
        h_out_from_incremental.append(ele_h_out)
    h_out_from_incremental = mx.np.stack(h_out_from_incremental, axis=1)

    for i in range(batch_size):
        val_length = dst_valid_length[i].asnumpy()
        assert_allclose(h_out_from_incremental[i, :val_length, :].asnumpy(),
                        h_out[i, :val_length, :].asnumpy(), 1E-5, 1E-5)

    # Test for the full decoder
    states = dec.init_states(batch_size, src_data.ctx, src_data.dtype)
    final_out_from_incremental = []
    for i in range(tgt_seq_length):
        ele_final_out, states = dec.incremental_decode(dst_data[:, i, :], states,
                                                       encoded_mem, src_valid_length)
        final_out_from_incremental.append(ele_final_out)
    final_out_from_incremental = mx.np.stack(final_out_from_incremental, axis=1)
    for i in range(batch_size):
        val_length = dst_valid_length[i].asnumpy()
        assert_allclose(final_out_from_incremental[i, :val_length, :].asnumpy(),
                        full_decode_out[i, :val_length, :].asnumpy(), 1E-5, 1E-5)
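
# Hedged note (assumption, not shown in this excerpt): in the test suite this
# function would normally be driven by pytest parametrization over `pre_norm`
# and the layer counts. The exact parameter grid below is illustrative only:
#
#     @pytest.mark.parametrize('pre_norm', [False, True])
#     @pytest.mark.parametrize('num_enc_layers', [1, 2])
#     @pytest.mark.parametrize('num_dec_layers', [1, 2])
#     def test_transformer_encoder_decoder(pre_norm, num_enc_layers, num_dec_layers):
#         ...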