Ejemplo n.º 1
0
def test_extract_spans_span_indices():
    """Check that per-document span bounds become absolute row indices.

    Two documents of lengths 5 and 10 are used. The first two spans belong
    to document one and the last span to document two. The expected values
    show that the second document's span (5, 7) lands on rows 10 and 11,
    i.e. it is shifted by the first document's length.
    """
    ops = extract_spans().initialize().ops
    span_bounds = ops.asarray([[0, 3], [2, 3], [5, 7]], dtype="i")
    spans_per_doc = ops.asarray([2, 1], dtype="i")
    doc_lengths = ops.asarray([5, 10], dtype="i")
    expected = [0, 1, 2, 2, 10, 11]
    actual = _get_span_indices(ops, Ragged(span_bounds, spans_per_doc), doc_lengths)
    assert list(actual) == expected
Ejemplo n.º 2
0
def test_extract_spans_forward_backward():
    """Run extract_spans forward and backward and check output/grad shapes.

    The forward pass is expected to yield one ragged segment per span, with
    lengths 3, 1 and 2 (end - start of each span). The backward pass is
    expected to return a gradient shaped like the input together with the
    very same spans object that was passed in.
    """
    model = extract_spans().initialize()
    ops = model.ops
    # 15 rows = the two documents' lengths (5 + 10); 4 features per row.
    X = Ragged(ops.alloc2f(15, 4), ops.asarray([5, 10], dtype="i"))
    span_bounds = ops.asarray([[0, 3], [2, 3], [5, 7]], dtype="i")
    spans = Ragged(span_bounds, ops.asarray([2, 1], dtype="i"))

    Y, backprop = model.begin_update((X, spans))
    assert list(Y.lengths) == [3, 1, 2]
    assert Y.dataXd.shape == (6, 4)

    # Reuse the forward output as the incoming gradient; only shapes matter here.
    dX, returned_spans = backprop(Y)
    assert returned_spans is spans
    assert dX.dataXd.shape == X.dataXd.shape
    assert list(dX.lengths) == list(X.lengths)
Ejemplo n.º 3
0
def test_init_extract_spans():
    """Smoke test: building and initializing extract_spans must not raise."""
    # NOTE(review): another function with this exact name is defined later in
    # this file and shadows this one, so pytest only collects the later
    # definition. Consider renaming or removing one of them.
    extract_spans().initialize()
Ejemplo n.º 4
0
def test_init_extract_spans():
    """Smoke test: constructing extract_spans and calling initialize() succeeds."""
    layer = extract_spans()
    layer.initialize()