def test_seq2col_window_one(ops, X):
    """Check that ``ops.seq2col`` with ``nW=1`` agrees with the base ``Ops``.

    A plain ``Ops`` instance is pointed at the same array module (``xp``) as
    the backend under test, so both results live in the same array library
    and can be compared directly.
    """
    X = ops.asarray(X)

    # Reference implementation, sharing the array module of `ops`.
    reference_ops = Ops()
    reference_ops.xp = ops.xp

    # Build the reference input from a fresh allocation plus X
    # (presumably `alloc` returns zeros, making this a copy -- confirm).
    reference_input = reference_ops.alloc(X.shape) + X
    expected = reference_ops.seq2col(
        reference_ops.asarray(reference_input), nW=1
    )
    actual = ops.seq2col(X, nW=1)

    ops.xp.testing.assert_allclose(expected, actual, atol=1e-3, rtol=1e-3)
示例#2
0
def apply_alignment(ops: Ops, align: Ragged,
                    X: Floats2d) -> Tuple[Ragged, Callable]:
    """Gather wordpiece rows of X per token and return a backprop callback.

    The alignment is a Ragged array: its lengths say how many wordpieces
    each token is aligned against, and its data holds the row indices into
    X. One token may map to several wordpieces, so the result is also a
    Ragged array. It reuses the alignment's lengths, and therefore has one
    data row per alignment entry. The number of entries in the lengths
    array is the number of tokens in the batch.

    Forward pass, conceptually:

        for i, index in enumerate(align.data):
            Y[i] = X[index]

    which is performed in one step with advanced indexing:

        Y = X[align.data]

    Backward pass, conceptually:

        for i, index in enumerate(align.data):
            X[index] += Y[i]

    A wordpiece may be aligned against more than one token, so gradients
    for repeated indices have to accumulate. A plain `X[index] = Y`
    assignment would not add them up, hence the 'scatter_add' op.

    If the alignment lengths sum to zero, the work is delegated to
    `_apply_empty_alignment`.

    Returns a tuple of the aligned Ragged output and a callback that maps
    a gradient with respect to that output to a gradient with respect to X.
    """
    total_aligned = align.lengths.sum()
    if not total_aligned:
        return _apply_empty_alignment(ops, align, X)

    # Remembered so the backward pass can allocate a gradient shaped like X.
    input_shape = X.shape
    row_ids = cast(Ints1d, align.dataXd)

    gathered = X[row_ids]
    token_lengths = cast(Ints1d, ops.asarray(align.lengths))
    aligned = Ragged(gathered, token_lengths)

    def backprop_apply_alignment(dY: Ragged) -> Floats2d:
        # The incoming gradient must have one row per alignment entry.
        assert dY.data.shape[0] == row_ids.shape[0]
        grad_X = ops.alloc2f(*input_shape)
        # Accumulate rather than assign: the same row index may repeat.
        ops.scatter_add(grad_X, row_ids, cast(Floats2d, dY.dataXd))
        return grad_X

    return aligned, backprop_apply_alignment