def test_index_update():
    """Smoke-test npx.index_update forward and backward on a tensor whose
    width exceeds the int32 range (INT_OVERFLOW columns)."""
    data = np.zeros((2, INT_OVERFLOW))
    # Two scatter targets in row 0: columns 0 and 1.
    positions = np.array([[0, 0], [0, 1]], dtype='int32')
    updates = np.array([100, 200])
    data.attach_grad()
    with mx.autograd.record():
        out = npx.index_update(data, positions, updates)
    # Shape is preserved and the two targeted cells carry the new values.
    assert out.shape == (2, INT_OVERFLOW)
    assert out[0][0] == 100 and out[0][1] == 200
    out.backward()
    assert data.grad.shape == (2, INT_OVERFLOW)
    # The gradient w.r.t. the input is zero at an updated position.
    assert data.grad[0][0] == 0
def update_vectors_by_position(data, val, positions):
    """Update each batch with the given positions.

    Considered as a reversed process of "select_vectors_by_position",
    this is an operator similar to "add_vectors_by_position" that
    updates the results instead of adding:

        data[i, positions[i, j]] = val[i, j]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length)
    val
        Input tensor of token ids
        Shape (batch_size, num_disp_position)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not
        exceed the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length)
    """
    # index_update expects integer indices.
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    # Broadcast against positions so every (i, j) entry carries its own
    # batch index i, giving batch_idx shape (batch_size, num_disp_position).
    batch_idx = batch_idx + np.zeros_like(positions)
    # indices has shape (2, batch_size * num_disp_position): row 0 is the
    # batch index, row 1 the position within the sequence.
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])
    # npx.reshape special codes: -5 fuses the first two axes of val and -4
    # keeps any remaining ones, flattening val to match the index columns.
    out = npx.index_update(data, indices, npx.reshape(val, (-5, -4)))
    return out