예제 #1
0
def test_coverage_attention(attention_coverage_type,
                            attention_coverage_num_hidden,
                            batch_size=3,
                            encoder_num_hidden=2,
                            decoder_num_hidden=2):
    """Run one step of coverage attention and check shapes and probabilities.

    Builds a "coverage" attention through ``sockeye.attention.get_attention``,
    groups the resulting context, probs and dynamic_source symbols, binds them
    on CPU, feeds random inputs and runs a single forward pass.

    :param attention_coverage_type: Forwarded to ``get_attention`` as
        ``attention_coverage_type``.
    :param attention_coverage_num_hidden: Forwarded to ``get_attention``; also
        the expected size of the last axis of the dynamic source output.
    :param batch_size: Leading dimension of all bound inputs.
    :param encoder_num_hidden: Size of the last axis of the ``source`` input.
    :param decoder_num_hidden: Size of the last axis of the ``decoder_state``
        input.
    """
    # source: (batch_size, seq_len, encoder_num_hidden)
    source = mx.sym.Variable("source")
    # source_length: (batch_size, )
    source_length = mx.sym.Variable("source_length")
    source_seq_len = 10

    attention = sockeye.attention.get_attention(
        input_previous_word=False,
        attention_type="coverage",
        attention_num_hidden=5,
        rnn_num_hidden=0,
        max_seq_len=source_seq_len,
        attention_coverage_type=attention_coverage_type,
        attention_coverage_num_hidden=attention_coverage_num_hidden)
    attention_state = attention.get_initial_state(source_length,
                                                  source_seq_len)
    attention_func = attention.on(source, source_length, source_seq_len)
    attention_input = attention.make_input(0, mx.sym.Variable("word_vec_prev"),
                                           mx.sym.Variable("decoder_state"))
    # A single attention step; the returned state carries the outputs we check.
    attention_state = attention_func(attention_input, attention_state)
    sym = mx.sym.Group([
        attention_state.context, attention_state.probs,
        attention_state.dynamic_source
    ])

    source_shape = (batch_size, source_seq_len, encoder_num_hidden)
    source_length_shape = (batch_size, )
    decoder_state_shape = (batch_size, decoder_num_hidden)

    executor = sym.simple_bind(ctx=mx.cpu(),
                               source=source_shape,
                               source_length=source_length_shape,
                               decoder_state=decoder_state_shape)

    source_length_vector = integer_vector(shape=source_length_shape,
                                          max_value=source_seq_len)
    executor.arg_dict["source"][:] = gaussian_vector(shape=source_shape)
    executor.arg_dict["source_length"][:] = source_length_vector
    executor.arg_dict["decoder_state"][:] = gaussian_vector(
        shape=decoder_state_shape)
    exec_output = executor.forward()
    context_result = exec_output[0].asnumpy()
    attention_prob_result = exec_output[1].asnumpy()
    dynamic_source_result = exec_output[2].asnumpy()

    # Expected probability per valid position is 1 / length, broadcast per row.
    # NOTE(review): if integer_vector can return 0 this divides by zero -
    # confirm the helper's value range.
    expected_probs = (1 / source_length_vector).reshape((batch_size, 1))

    assert context_result.shape == (batch_size, encoder_num_hidden)
    assert attention_prob_result.shape == (batch_size, source_seq_len)
    # Only the shape of the dynamic source is verified, not its values.
    assert dynamic_source_result.shape == (batch_size, source_seq_len,
                                           attention_coverage_num_hidden)
    # Per row, the number of positions whose probability equals 1 / length
    # must equal that row's source length.
    assert (np.sum(np.isclose(attention_prob_result, expected_probs),
                   axis=1) == source_length_vector).all()
예제 #2
0
def test_attention(attention_type,
                   batch_size=1,
                   encoder_num_hidden=2,
                   decoder_num_hidden=2):
    """Run one attention step on a fixed input and check context and probs.

    An attention of the given type is built from an ``AttentionConfig``, bound
    on CPU and fed a three-position source of which only the first two
    positions are valid (``source_length`` is 2).

    :param attention_type: Value passed as ``type`` to ``AttentionConfig``.
    :param batch_size: Leading dimension of all bound inputs.
    :param encoder_num_hidden: Size of the last axis of the ``source`` input.
    :param decoder_num_hidden: Size of the last axis of the ``decoder_state``
        input.
    """
    # source: (batch_size, seq_len, encoder_num_hidden)
    source = mx.sym.Variable("source")
    # source_length: (batch_size,)
    source_length = mx.sym.Variable("source_length")
    source_seq_len = 3

    attention = sockeye.attention.get_attention(
        sockeye.attention.AttentionConfig(type=attention_type,
                                          num_hidden=2,
                                          input_previous_word=False,
                                          rnn_num_hidden=2,
                                          layer_normalization=False,
                                          config_coverage=None),
        max_seq_len=source_seq_len)

    initial_state = attention.get_initial_state(source_length, source_seq_len)
    step = attention.on(source, source_length, source_seq_len)
    step_input = attention.make_input(0, mx.sym.Variable("word_vec_prev"),
                                      mx.sym.Variable("decoder_state"))
    next_state = step(step_input, initial_state)
    outputs = mx.sym.Group([next_state.context, next_state.probs])

    input_shapes = {
        "source": (batch_size, source_seq_len, encoder_num_hidden),
        "source_length": (batch_size, ),
        "decoder_state": (batch_size, decoder_num_hidden),
    }
    executor = outputs.simple_bind(ctx=mx.cpu(), **input_shapes)

    # TODO: test for other inputs (that are not equal at each source position)
    inputs = executor.arg_dict
    inputs["source"][:] = np.asarray([[[1., 2.], [1., 2.], [3., 4.]]])
    inputs["source_length"][:] = np.asarray([2.0])
    inputs["decoder_state"][:] = np.asarray([[5, 6]])

    context_result, attention_prob_result = (
        out.asnumpy() for out in executor.forward())

    # expecting uniform attention_weights of 0.5: 0.5 * seq1 + 0.5 * seq2
    expected_context = np.asarray([[1., 2.]])
    assert np.allclose(context_result, expected_context)
    # equal attention to first two and no attention to third
    expected_probs = np.asarray([[0.5, 0.5, 0.]])
    assert np.allclose(attention_prob_result, expected_probs)
예제 #3
0
def test_last_state_attention(batch_size=1, encoder_num_hidden=2):
    """
    Check the "fixed" attention type on a small, hand-written input.

    EncoderLastStateAttention differs from the other attention mechanisms: it takes no query argument
    and doesn't return a probability distribution over the inputs (aka alignment).

    :param batch_size: Leading dimension of all bound inputs.
    :param encoder_num_hidden: Size of the last axis of the ``source`` input.
    """
    # source: (batch_size, seq_len, encoder_num_hidden)
    source = mx.sym.Variable("source")
    # source_length: (batch_size,)
    source_length = mx.sym.Variable("source_length")
    source_seq_len = 3

    attention = sockeye.attention.get_attention(
        sockeye.attention.AttentionConfig(type="fixed",
                                          num_hidden=0,
                                          input_previous_word=False,
                                          rnn_num_hidden=0,
                                          layer_normalization=False,
                                          config_coverage=None),
        max_seq_len=source_seq_len)

    initial_state = attention.get_initial_state(source_length, source_seq_len)
    step = attention.on(source, source_length, source_seq_len)
    step_input = attention.make_input(0, mx.sym.Variable("word_vec_prev"),
                                      mx.sym.Variable("decoder_state"))
    next_state = step(step_input, initial_state)
    outputs = mx.sym.Group([next_state.context, next_state.probs])

    # No decoder_state shape is supplied when binding this attention type.
    input_shapes = {
        "source": (batch_size, source_seq_len, encoder_num_hidden),
        "source_length": (batch_size, ),
    }
    executor = outputs.simple_bind(ctx=mx.cpu(), **input_shapes)

    # TODO: test for other inputs (that are not equal at each source position)
    inputs = executor.arg_dict
    inputs["source"][:] = np.asarray([[[1., 2.], [1., 2.], [3., 4.]]])
    inputs["source_length"][:] = np.asarray([2.0])

    context_result, attention_prob_result = (
        out.asnumpy() for out in executor.forward())

    # expecting attention on last state based on source_length
    expected_context = np.asarray([[1., 2.]])
    expected_probs = np.asarray([[0., 1.0, 0.]])
    assert np.allclose(context_result, expected_context)
    assert np.allclose(attention_prob_result, expected_probs)