Esempio n. 1
0
def test_scaled_dot_product_attention5():
    """
    value different dimension to query and key, AND
    query and key can have different sequence length, AND
    key and value must have same sequence length

    The attention output is expected to follow the query's sequence axis and
    the value's feature dimension, i.e. each result is (query_len, value_dim).
    """
    s1, s2 = 4, 2    # key/value sequence lengths
    s3, s4 = 5, 10   # query sequence lengths (deliberately different)

    # query lives on its own dynamic axis; key and value share a second one
    seq1 = C.Axis.new_unique_dynamic_axis('seq1')
    seq2 = C.Axis.new_unique_dynamic_axis('seq2')

    query = C.sequence.input_variable(5, sequence_axis=seq1)
    key = C.sequence.input_variable(5, sequence_axis=seq2)
    value = C.sequence.input_variable(7, sequence_axis=seq2)

    b = ScaledDotProductAttention()(query, key, value)

    # output feature dimension is expected to come from value, not query
    assert b.shape == (7, ), f"output dimension should follow value, got {b.shape}"

    q1 = np.random.random((s3, 5)).astype(np.float32)
    q2 = np.random.random((s4, 5)).astype(np.float32)

    k1 = np.random.random((s1, 5)).astype(np.float32)
    k2 = np.random.random((s2, 5)).astype(np.float32)

    v1 = np.random.random((s1, 7)).astype(np.float32)
    v2 = np.random.random((s2, 7)).astype(np.float32)

    results = b.eval({query: [q1, q2], key: [k1, k2], value: [v1, v2]})

    # Without these checks the test only proved that eval() does not raise.
    assert results[0].shape == (s3, 7), f"Wrong expected shape {results[0].shape} != {(s3, 7)}"
    assert results[1].shape == (s4, 7), f"Wrong expected shape {results[1].shape} != {(s4, 7)}"
Esempio n. 2
0
def test_scaled_dot_product_attention6():
    """
    key and value must have same sequence length

    Here key shares the query's dynamic axis (seq1) while value sits on a
    different axis (seq2) with different lengths, so evaluation is expected
    to fail.
    """
    s1, s2 = 4, 2    # value sequence lengths
    s3, s4 = 5, 10   # query/key sequence lengths

    seq1 = C.Axis.new_unique_dynamic_axis('seq1')
    seq2 = C.Axis.new_unique_dynamic_axis('seq2')

    query = C.sequence.input_variable(5, sequence_axis=seq1)
    key = C.sequence.input_variable(5, sequence_axis=seq1)
    value = C.sequence.input_variable(7, sequence_axis=seq2)

    # Construction is intentionally outside pytest.raises: this test expects
    # the key/value length mismatch to surface only when fed concrete data.
    b = ScaledDotProductAttention()(query, key, value)

    q1 = np.random.random((s3, 5)).astype(np.float32)
    q2 = np.random.random((s4, 5)).astype(np.float32)

    k1 = np.random.random((s3, 5)).astype(np.float32)
    k2 = np.random.random((s4, 5)).astype(np.float32)

    v1 = np.random.random((s1, 7)).astype(np.float32)
    v2 = np.random.random((s2, 7)).astype(np.float32)

    with pytest.raises(Exception):
        b.eval({query: [q1, q2], key: [k1, k2], value: [v1, v2]})
Esempio n. 3
0
def test_scaled_dot_product_attention1():
    """Self-attention with default settings keeps every sequence's shape."""
    len_first, len_second = 4, 2
    seq_in = C.sequence.input_variable(5)
    attention = ScaledDotProductAttention()(seq_in, seq_in, seq_in)

    assert attention.shape == (5, ), "output should be a sequence and dimension should not change"

    first = np.random.random((len_first, 5)).astype(np.float32)
    second = np.random.random((len_second, 5)).astype(np.float32)

    outputs = attention.eval({seq_in: [first, second]})

    # the second sequence is checked before the first, preserving the original order
    for idx, expected in ((1, second), (0, first)):
        got = outputs[idx].shape
        assert got == expected.shape, f"Wrong expected shape {got} != {expected.shape}"
Esempio n. 4
0
def test_scaled_dot_product_attention2():
    """ returns a sequence while not peeking on future values """
    s1, s2 = 4, 2
    a = C.sequence.input_variable(5)
    b = ScaledDotProductAttention(obey_sequence_order=True, max_seq_len=100)(a, a, a)

    assert b.shape == (5, ), "output should be a sequence and dimension should not change"

    n1 = np.random.random((s1, 5)).astype(np.float32)
    n2 = np.random.random((s2, 5)).astype(np.float32)

    results = b.eval({a: [n1, n2]})
    assert results[1].shape == n2.shape, f"Wrong expected shape {results[1].shape} != {n2.shape}"
    assert results[0].shape == n1.shape, f"Wrong expected shape {results[0].shape} != {n1.shape}"

    # Shape checks alone cannot show the "no peeking" property. If only the
    # last step of n1 changes, every earlier output step of that sequence
    # (and the whole other sequence in the batch) must stay the same.
    n1_future_changed = n1.copy()
    n1_future_changed[-1] += 1.0
    perturbed = b.eval({a: [n1_future_changed, n2]})

    np.testing.assert_allclose(perturbed[0][:-1], results[0][:-1], rtol=1e-5, atol=1e-6)
    np.testing.assert_allclose(perturbed[1], results[1], rtol=1e-5, atol=1e-6)
Esempio n. 5
0
def test_scaled_dot_product_attention3():
    """ query and key must have the same dimension (here key/value is 20-dim, query is 5-dim) """
    query = C.sequence.input_variable(5)
    keyvalue = C.sequence.input_variable(20)

    # The mismatch is expected to be rejected while the graph is being built,
    # before any data is fed in, so only construction goes inside raises.
    with pytest.raises(Exception):
        ScaledDotProductAttention()(query, keyvalue, keyvalue)
Esempio n. 6
0
def test_scaled_dot_product_attention4():
    """ value can have a completely different dimension from the query or key """
    s1, s2 = 4, 2

    query = C.sequence.input_variable(5)
    key = C.sequence.input_variable(5)
    value = C.sequence.input_variable(7)

    b = ScaledDotProductAttention()(query, key, value)

    # output feature dimension is expected to come from value, not query
    assert b.shape == (7, ), f"output dimension should follow value, got {b.shape}"

    n1 = np.random.random((s1, 5)).astype(np.float32)
    n2 = np.random.random((s2, 5)).astype(np.float32)

    m1 = np.random.random((s1, 7)).astype(np.float32)
    m2 = np.random.random((s2, 7)).astype(np.float32)

    results = b.eval({query: [n1, n2], key: [n1, n2], value: [m1, m2]})

    # Without these checks the test only proved that eval() does not raise.
    # Expected per-sequence shape: (query length, value dimension).
    assert results[0].shape == (s1, 7), f"Wrong expected shape {results[0].shape} != {(s1, 7)}"
    assert results[1].shape == (s2, 7), f"Wrong expected shape {results[1].shape} != {(s2, 7)}"