def test_apply_alignment(nested_align, X_cols):
    """Run apply_alignment forward and backward and check output shapes.

    ``nested_align`` is converted to a ragged alignment with ``get_ragged``.
    An input array is allocated with ``max(alignment data) + 1`` rows and
    ``X_cols`` columns, then passed through ``apply_alignment``. The test
    asserts that:

    * the forward result is a ``Ragged``;
    * its data has as many rows as the alignment has entries;
    * it has one length per element of ``nested_align``;
    * the backward callback returns a gradient shaped like the input.
    """
    numpy_ops = NumpyOps()
    ragged_align = get_ragged(numpy_ops, nested_align)
    # Size the input so the largest alignment value fits as a row position
    # (presumably align.data holds row indices into X -- confirm with
    # apply_alignment's implementation).
    n_rows = ragged_align.data.max() + 1
    inputs = numpy_ops.alloc2f(n_rows, X_cols)

    output, backprop = apply_alignment(numpy_ops, ragged_align, inputs)
    assert isinstance(output, Ragged)
    assert output.data.shape[0] == ragged_align.data.shape[0]
    assert output.lengths.shape[0] == len(nested_align)

    # Feed the forward output back as the upstream gradient; only the
    # resulting shape is checked here.
    d_inputs = backprop(output)
    assert d_inputs.shape == inputs.shape
def get_input(nr_batch, nr_in):
    """Return a new ``(nr_batch, nr_in)`` array allocated by ``NumpyOps``.

    The array comes from ``NumpyOps().alloc2f``. Its exact contents and dtype
    are whatever ``alloc2f`` provides (presumably zero-filled float32 -- see
    the Ops documentation).
    """
    return NumpyOps().alloc2f(nr_batch, nr_in)