def test_index_update():
    """Check npx.index_update forward values and backward gradients on a
    (2, INT_OVERFLOW) array.

    Positions (0, 0) and (0, 1) are overwritten with 100 and 200. The
    gradient w.r.t. ``A`` must be zero at the overwritten positions, because
    their original values no longer reach the output. Every other position
    must pass the head gradient through.
    """
    A = np.zeros((2, INT_OVERFLOW))
    # Indices are laid out per axis: row indices [0, 0], column indices
    # [0, 1], i.e. the positions (0, 0) and (0, 1).
    ind = np.array([[0, 0], [0, 1]], dtype='int32')
    val = np.array([100, 200])
    A.attach_grad()
    with mx.autograd.record():
        B = npx.index_update(A, ind, val)
    assert B.shape == (2, INT_OVERFLOW)
    # Separate assertions so a failure pinpoints the offending position.
    assert B[0][0] == 100
    assert B[0][1] == 200
    # A position not listed in `ind` keeps A's original (zero) value.
    assert B[1][0] == 0
    # With no argument, backward() uses a head gradient of ones.
    B.backward()
    assert A.grad.shape == (2, INT_OVERFLOW)
    # Both overwritten positions receive no gradient...
    assert A.grad[0][0] == 0
    assert A.grad[0][1] == 0
    # ...while an untouched position passes the head gradient through.
    assert A.grad[1][0] == 1
# Example no. 2
# 0
def update_vectors_by_position(data, val, positions):
    """
    Update each batch with the given positions. Considered as a reversed process of
    "select_vectors_by_position", this is an operator similar to "add_vectors_by_position"
    that overwrites the selected entries instead of adding to them.

    data[i, positions[i, j], :] = val[i, j, :]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length, ...)
    val
        Values written into ``data`` at the given positions.
        Shape (batch_size, num_disp_position, ...), where the trailing
        dimensions match the trailing dimensions of ``data``.
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, every value must be a valid index along
        the sequence axis, i.e. lie in ``[0, seq_length)``.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length, ...), same as ``data``.
    """
    # Cast so the positions share the int32 dtype of the batch indices below.
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    # Broadcast to (batch_size, num_disp_position) so batch_idx[i, j] == i.
    batch_idx = batch_idx + np.zeros_like(positions)
    # indices.shape = (2, batch_size * num_disp_position): row 0 holds batch
    # indices and row 1 holds sequence positions (one index row per axis).
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])

    # Merge val's first two axes (-5) and copy the remaining ones (-4) so its
    # leading dimension lines up with the columns of `indices`.
    out = npx.index_update(data, indices, npx.reshape(val, (-5, -4)))
    return out