import mxnet as mx
from mxnet import np, npx

# Any dimension of this size overflows a signed 32-bit index, which is what
# the large-tensor test below exercises.
INT_OVERFLOW = 2**31


def test_index_add():
    A = np.zeros((2, INT_OVERFLOW))
    ind = np.array([[0, 0], [0, 1]], dtype='int32')
    val = np.array([100, 200])
    A.attach_grad()
    with mx.autograd.record():
        B = npx.index_add(A, ind, val)
        assert B.shape == (2, INT_OVERFLOW)
        assert B[0][0] == 100 and B[0][1] == 200
        B.backward()
    assert A.grad.shape == (2, INT_OVERFLOW)
    assert A.grad[0][0] == 1
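# A minimal sketch of the npx.index_add semantics exercised above, on a small
# array so it runs without the 2**31-sized allocation. The *_small names are
# illustrative and not part of the test suite. ind_small[0] holds the row
# indices and ind_small[1] the column indices of each update.
A_small = np.zeros((2, 3))
ind_small = np.array([[0, 0], [0, 1]], dtype='int32')
val_small = np.array([100, 200])
B_small = npx.index_add(A_small, ind_small, val_small)
# B_small == [[100, 200, 0], [0, 0, 0]]; A_small itself is left unchanged.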
def add_vectors_by_position(data, increment, positions):
    """Scatter-add the increment into each batch at the given positions.

    data[i, positions[i, j], ...] += increment[i, j, ...]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length, ...)
    increment
        Input tensor of the values to add.
        Shape (batch_size, num_disp_position, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length, ...)
    """
    # Here, we use index_add to scatter the increment into data:
    # Need to compute
    #   out[i, positions[i, j], :] += increment[i, j, :]
    # Thus, construct indices with shape (2, batch_size * num_disp_position), where
    #   indices[0, i * num_disp_position + j] = i
    #   indices[1, i * num_disp_position + j] = positions[i, j]
    # and flatten the leading two axes of increment to shape
    # (batch_size * num_disp_position, ...).
    # Then, out = npx.index_add(data, indices, increment)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    # Broadcast to (batch_size, num_disp_position) so each position gets its batch id.
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1,)), positions.reshape((-1,))])
    # npx.reshape with (-5, -4) merges the first two axes and copies the rest.
    out = npx.index_add(data, indices, npx.reshape(increment, (-5, -4)))
    return out
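# Hypothetical usage sketch of add_vectors_by_position; the shapes and values
# below are chosen for illustration and are not taken from the library's tests.
data = np.zeros((2, 4, 3))       # (batch_size=2, seq_length=4, units=3)
increment = np.ones((2, 2, 3))   # (batch_size=2, num_disp_position=2, units=3)
positions = np.array([[0, 2], [1, 3]], dtype='int32')
out = add_vectors_by_position(data, increment, positions)
# Rows out[0, 0], out[0, 2], out[1, 1], and out[1, 3] are now all ones;
# every other row of `out` remains zero.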