import mxnet as mx
import numpy as np

import sockeye.attention


def test_coverage_attention(attention_coverage_type,
                            attention_coverage_num_hidden,
                            batch_size=3,
                            encoder_num_hidden=2,
                            decoder_num_hidden=2):
    # source: (batch_size, seq_len, encoder_num_hidden)
    source = mx.sym.Variable("source")
    # source_length: (batch_size,)
    source_length = mx.sym.Variable("source_length")
    source_seq_len = 10

    attention = sockeye.attention.get_attention(input_previous_word=False,
                                                attention_type="coverage",
                                                attention_num_hidden=5,
                                                rnn_num_hidden=0,
                                                max_seq_len=source_seq_len,
                                                attention_coverage_type=attention_coverage_type,
                                                attention_coverage_num_hidden=attention_coverage_num_hidden)

    attention_state = attention.get_initial_state(source_length, source_seq_len)
    attention_func = attention.on(source, source_length, source_seq_len)
    attention_input = attention.make_input(0, mx.sym.Variable("word_vec_prev"), mx.sym.Variable("decoder_state"))
    attention_state = attention_func(attention_input, attention_state)

    sym = mx.sym.Group([attention_state.context, attention_state.probs, attention_state.dynamic_source])

    source_shape = (batch_size, source_seq_len, encoder_num_hidden)
    source_length_shape = (batch_size,)
    decoder_state_shape = (batch_size, decoder_num_hidden)

    executor = sym.simple_bind(ctx=mx.cpu(),
                               source=source_shape,
                               source_length=source_length_shape,
                               decoder_state=decoder_state_shape)

    source_length_vector = integer_vector(shape=source_length_shape, max_value=source_seq_len)
    executor.arg_dict["source"][:] = gaussian_vector(shape=source_shape)
    executor.arg_dict["source_length"][:] = source_length_vector
    executor.arg_dict["decoder_state"][:] = gaussian_vector(shape=decoder_state_shape)
    exec_output = executor.forward()
    context_result = exec_output[0].asnumpy()
    attention_prob_result = exec_output[1].asnumpy()
    dynamic_source_result = exec_output[2].asnumpy()

    expected_probs = (1 / source_length_vector).reshape((batch_size, 1))
    expected_dynamic_source = (1 / source_length_vector).reshape((batch_size, 1))

    assert context_result.shape == (batch_size, encoder_num_hidden)
    assert attention_prob_result.shape == (batch_size, source_seq_len)
    assert dynamic_source_result.shape == (batch_size, source_seq_len, attention_coverage_num_hidden)
    assert (np.sum(np.isclose(attention_prob_result, expected_probs), axis=1) == source_length_vector).all()
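
# The two helpers used above (gaussian_vector, integer_vector) are assumed to come from the
# shared test utilities; they are not defined in this file. The sketch below shows what they
# could look like: random inputs of a given shape, with integer_vector returning values in
# [1, max_value] so that 1 / source_length_vector above is well defined.
def gaussian_vector(shape):
    """Random array of the given shape, drawn from a standard normal distribution."""
    return np.random.normal(size=shape)


def integer_vector(shape, max_value):
    """Random array of integers in [1, max_value] with the given shape."""
    return np.random.randint(1, max_value + 1, size=shape)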
def test_attention(attention_type,
                   batch_size=1,
                   encoder_num_hidden=2,
                   decoder_num_hidden=2):
    # source: (batch_size, seq_len, encoder_num_hidden)
    source = mx.sym.Variable("source")
    # source_length: (batch_size,)
    source_length = mx.sym.Variable("source_length")
    source_seq_len = 3

    config_attention = sockeye.attention.AttentionConfig(type=attention_type,
                                                         num_hidden=2,
                                                         input_previous_word=False,
                                                         rnn_num_hidden=2,
                                                         layer_normalization=False,
                                                         config_coverage=None)
    attention = sockeye.attention.get_attention(config_attention, max_seq_len=source_seq_len)

    attention_state = attention.get_initial_state(source_length, source_seq_len)
    attention_func = attention.on(source, source_length, source_seq_len)
    attention_input = attention.make_input(0, mx.sym.Variable("word_vec_prev"), mx.sym.Variable("decoder_state"))
    attention_state = attention_func(attention_input, attention_state)

    sym = mx.sym.Group([attention_state.context, attention_state.probs])

    executor = sym.simple_bind(ctx=mx.cpu(),
                               source=(batch_size, source_seq_len, encoder_num_hidden),
                               source_length=(batch_size,),
                               decoder_state=(batch_size, decoder_num_hidden))

    # TODO: test for other inputs (that are not equal at each source position)
    executor.arg_dict["source"][:] = np.asarray([[[1., 2.], [1., 2.], [3., 4.]]])
    executor.arg_dict["source_length"][:] = np.asarray([2.0])
    executor.arg_dict["decoder_state"][:] = np.asarray([[5, 6]])
    exec_output = executor.forward()
    context_result = exec_output[0].asnumpy()
    attention_prob_result = exec_output[1].asnumpy()

    # expecting uniform attention_weights of 0.5: 0.5 * seq1 + 0.5 * seq2
    assert np.isclose(context_result, np.asarray([[1., 2.]])).all()
    # equal attention to first two and no attention to third
    assert np.isclose(attention_prob_result, np.asarray([[0.5, 0.5, 0.]])).all()
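
# The expected values in test_attention can be re-derived with plain numpy. The sketch below
# is not sockeye code; it only reproduces the masked-softmax arithmetic the assertions rely
# on: positions at or beyond source_length get zero probability, and the context is the
# probability-weighted sum of encoder states.
def masked_attention_reference(source, source_length, scores):
    """Numpy reference for the context/probs asserted above.

    source: (batch, seq_len, num_hidden), source_length: (batch,), scores: (batch, seq_len).
    """
    batch_size, seq_len, _ = source.shape
    mask = np.arange(seq_len)[None, :] < source_length[:, None]
    exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True)) * mask
    probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    context = np.einsum('bs,bsh->bh', probs, source)
    return context, probs


def test_attention_reference_values():
    # Equal scores over the two valid positions [1., 2.] and [1., 2.] give probs
    # [0.5, 0.5, 0.0] and context [1., 2.], matching the assertions in test_attention.
    source = np.asarray([[[1., 2.], [1., 2.], [3., 4.]]])
    context, probs = masked_attention_reference(source, np.asarray([2]), np.zeros((1, 3)))
    assert np.isclose(probs, np.asarray([[0.5, 0.5, 0.]])).all()
    assert np.isclose(context, np.asarray([[1., 2.]])).all()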
def test_last_state_attention(batch_size=1,
                              encoder_num_hidden=2):
    """
    EncoderLastStateAttention is a bit different from other attention mechanisms as it doesn't take a query argument
    and doesn't return a probability distribution over the inputs (aka alignment).
    """
    # source: (batch_size, seq_len, encoder_num_hidden)
    source = mx.sym.Variable("source")
    # source_length: (batch_size,)
    source_length = mx.sym.Variable("source_length")
    source_seq_len = 3

    config_attention = sockeye.attention.AttentionConfig(type="fixed",
                                                         num_hidden=0,
                                                         input_previous_word=False,
                                                         rnn_num_hidden=0,
                                                         layer_normalization=False,
                                                         config_coverage=None)
    attention = sockeye.attention.get_attention(config_attention, max_seq_len=source_seq_len)

    attention_state = attention.get_initial_state(source_length, source_seq_len)
    attention_func = attention.on(source, source_length, source_seq_len)
    attention_input = attention.make_input(0, mx.sym.Variable("word_vec_prev"), mx.sym.Variable("decoder_state"))
    attention_state = attention_func(attention_input, attention_state)

    sym = mx.sym.Group([attention_state.context, attention_state.probs])

    executor = sym.simple_bind(ctx=mx.cpu(),
                               source=(batch_size, source_seq_len, encoder_num_hidden),
                               source_length=(batch_size,))

    # TODO: test for other inputs (that are not equal at each source position)
    executor.arg_dict["source"][:] = np.asarray([[[1., 2.], [1., 2.], [3., 4.]]])
    executor.arg_dict["source_length"][:] = np.asarray([2.0])
    exec_output = executor.forward()
    context_result = exec_output[0].asnumpy()
    attention_prob_result = exec_output[1].asnumpy()

    # expecting attention on last state based on source_length
    assert np.isclose(context_result, np.asarray([[1., 2.]])).all()
    assert np.isclose(attention_prob_result, np.asarray([[0., 1.0, 0.]])).all()
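
# Similarly, the "fixed" attention expectations in test_last_state_attention can be checked
# with a tiny numpy sketch (again, not sockeye code): the context is simply the last encoder
# state within source_length, i.e. a one-hot attention distribution at index source_length - 1.
def test_last_state_reference_values():
    source = np.asarray([[[1., 2.], [1., 2.], [3., 4.]]])
    source_length = np.asarray([2])
    last_index = source_length[0] - 1
    probs = np.zeros((1, source.shape[1]))
    probs[0, last_index] = 1.0
    context = source[0, last_index][None, :]
    # Matches the assertions above: probs [[0., 1., 0.]] and context [[1., 2.]].
    assert np.isclose(probs, np.asarray([[0., 1., 0.]])).all()
    assert np.isclose(context, np.asarray([[1., 2.]])).all()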