def test_index_add():
    """Check ``npx.index_add`` forward and backward on a (2, INT_OVERFLOW) array.

    Two values are added into a zero array at coordinates (0, 0) and (0, 1);
    the test then verifies the output shape and the two written entries, runs
    backward, and verifies the gradient's shape and its [0][0] entry.

    NOTE(review): ``INT_OVERFLOW`` is defined outside this block; presumably
    it is a size beyond the 32-bit index range -- confirm against its
    definition.
    """
    expected_shape = (2, INT_OVERFLOW)

    target = np.zeros(expected_shape)
    # Row 0 of the index array holds first-axis coordinates, row 1 holds
    # second-axis coordinates, so the addressed cells are (0, 0) and (0, 1).
    coords = np.array([[0, 0], [0, 1]], dtype='int32')
    addends = np.array([100, 200])

    target.attach_grad()
    with mx.autograd.record():
        result = npx.index_add(target, coords, addends)

    # Forward: shape is unchanged and both addressed cells hold their addend.
    assert result.shape == expected_shape
    assert result[0][0] == 100
    assert result[0][1] == 200

    # Backward: the gradient has the input's shape and is 1 at [0][0].
    result.backward()
    grad = target.grad
    assert grad.shape == expected_shape
    assert grad[0][0] == 1
# Example #2
0
def add_vectors_by_position(data, increment, positions):
    """Add per-batch increment vectors into ``data`` at the given positions.

    Intended update rule::

        data[i, positions[i, j], ...] += increment[i, j, ...]

    The result is returned as a new array produced by ``npx.index_add``.

    Parameters
    ----------
    data
        Input tensor to be updated.
        Shape (batch_size, seq_length, ...)
    increment
        Input tensor holding the vectors to add.
        Shape (batch_size, num_disp_position, ...)
    positions
        Input tensor of the positions along the sequence axis.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length, ...)
    """
    # Strategy: express the scatter-add as a single npx.index_add call.
    # index_add takes a coordinate array of shape
    # (2, batch_size * num_disp_position) where
    #     coords[0, i * num_disp_position + j] = i
    #     coords[1, i * num_disp_position + j] = positions[i, j]
    # together with the increments flattened over their first two axes.
    pos_int = positions.astype(np.int32)

    # Column vector of batch ids: [[0], [1], [2], ...] with shape (batch_size, 1).
    row_ids = npx.arange_like(pos_int, axis=0)
    row_ids = np.expand_dims(row_ids, axis=1).astype(np.int32)
    # Adding zeros shaped like the positions broadcasts the ids to
    # (batch_size, num_disp_position).
    row_ids = row_ids + np.zeros_like(pos_int)

    flat_rows = row_ids.reshape((-1, ))
    flat_cols = pos_int.reshape((-1, ))
    coords = np.stack([flat_rows, flat_cols])

    # In npx.reshape, -5 merges two consecutive input dims into one and -4
    # copies all remaining dims, giving (batch_size * num_disp_position, ...).
    flat_increment = npx.reshape(increment, (-5, -4))

    return npx.index_add(data, coords, flat_increment)