def test_gather_nd():
    """Check gather_nd values and gradients on a large tensor.

    The last axis has length INT_OVERFLOW (presumably past the int32 limit,
    hence the int64 indices).
    """
    A = np.ones((1, 2, INT_OVERFLOW))
    A[0, 1, 100] = 100
    A.attach_grad()
    with mx.autograd.record():
        # indices has shape (3, 2): row k holds the coordinates along axis k
        # of the two gathered elements, A[0, 0, INT_OVERFLOW - 1] and
        # A[0, 1, 100].
        B = npx.gather_nd(
            data=A,
            indices=np.array([[0, 0], [0, 1], [INT_OVERFLOW - 1, 100]],
                             dtype='int64'))
    assert B.shape == (2, )
    assert B[0] == 1
    assert B[1] == 100
    B.backward()
    assert A.grad.shape == (1, 2, INT_OVERFLOW)
    # An element that was not gathered receives no gradient ...
    assert A.grad[0, 0, 0] == 0
    # ... while each gathered element receives a gradient of one.
    assert A.grad[0, 0, INT_OVERFLOW - 1] == 1
    assert A.grad[0, 1, 100] == 1
def select_vectors_by_position(data, positions):
    """Select vectors from each sample of the batch at the given positions.

    Once advanced indexing can be hybridized, we can revise the implementation.

    out[i, j, ...] = data[i, positions[i, j], ...]

    Parameters
    ----------
    data
        Input tensor of contextualized token embeddings
        Shape (batch_size, seq_length, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_sel_positions).
        For each sample in the batch, the values in this tensor must be valid
        indices into the sequence axis, i.e. lie in [0, seq_length).

    Returns
    -------
    out
        The selection result.
        Shape (batch_size, num_sel_positions, ...)
    """
    # Here, we use gather_nd to select the output from data.
    # We need to compute
    #     out[i, j, :] = data[i, positions[i, j], :]
    # Thus, construct an indices tensor with shape
    # (2, batch_size, num_sel_positions), where
    #     indices[0, i, j] = i
    #     indices[1, i, j] = positions[i, j]
    # Then, out = gather_nd(data, indices)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    # Broadcast to (batch_size, num_sel_positions) so it can be stacked
    # with positions.
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx, positions])
    # TODO(sxjscience) We can revise the implementation to advanced indexing
    #  once the bug in MXNet is solved.
    out = npx.gather_nd(data, indices)
    return out
def forward(self, rel_positions, query=None):
    """Compute the relative positional attention scores.

    Parameters
    ----------
    rel_positions
        The relative shifts, shape (query_length, mem_length). Element (i, j)
        is the shift between the :math:`i-th` element of the query and the
        :math:`j-th` element of the memory.
    query
        The query used to compute the relative scores; its shape depends on
        ``self._layout``. Required when ``self._method`` is 'transformer_xl'
        or 'shaw'; not used by the 't5' method.

    Returns
    -------
    rel_scores
        The relative attention scores, with shape
        (batch_size, num_heads, query_length, mem_length) or
        (num_heads, query_length, mem_length).

    Raises
    ------
    AssertionError
        If ``query`` is None for the 'transformer_xl' / 'shaw' methods.
    NotImplementedError
        If ``self._method`` or ``self._layout`` is not supported.
    """
    # The T5 variant embeds the shifts directly and never touches the query.
    if self._method == 't5':
        # shape is (K, L_q, L_m)
        return self._rel_pos_embed(rel_positions).transpose((2, 0, 1))
    if self._method not in ('transformer_xl', 'shaw'):
        raise NotImplementedError
    assert query is not None, 'Must specify query if method={}'.format(self._method)

    # Bound the shifts: [-max, max] when bidirectional, [0, max] otherwise.
    bidirectional = self._bidirectional
    if self._max_distance is not None:
        rel_positions = np.clip(rel_positions,
                                a_min=-self._max_distance if bidirectional else 0,
                                a_max=self._max_distance)

    # Embed every distinct shift only once.
    # uniq_vals.shape = (#uniq,), inverse_idx.shape = (L_q, L_m)
    uniq_vals, inverse_idx = np.unique(rel_positions, return_inverse=True)
    pos_embed = self._rel_pos_embed(uniq_vals)
    if self._method == 'transformer_xl':
        pos_embed = self._rel_proj(self._dropout_layer(pos_embed))
    # Shape (#uniq, K, C_q)
    pos_embed = npx.reshape(pos_embed, (-2, self._num_heads, self._head_query_units))

    # Dot-product between the query and the unique positional embeddings;
    # afterwards rel_score.shape = (L_q, #uniq, N, K).
    layout = self._layout
    if layout not in ('NKT', 'NTK', 'TNK'):
        raise NotImplementedError
    if self._use_einsum:
        einsum_spec = {
            'NKT': 'bnid,jnd->ijbn',  # query: (N, K, L_q, C_q)
            'NTK': 'bind,jnd->ijbn',  # query: (N, L_q, K, C_q)
            'TNK': 'ibnd,jnd->ijbn',  # query: (L_q, N, K, C_q)
        }[layout]
        rel_score = np.einsum(einsum_spec, query, pos_embed)
    else:
        # Bring the query to (N, K, L_q, C_q) before the batched matmul.
        if layout == 'NKT':
            query_nkt = query
        elif layout == 'NTK':
            query_nkt = np.swapaxes(query, 1, 2)
        else:
            query_nkt = np.transpose(query, (1, 2, 0, 3))
        # (N, K, L_q, C_q) @ (K, C_q, #uniq) -> (N, K, L_q, #uniq)
        scores = np.matmul(query_nkt, np.transpose(pos_embed, (1, 2, 0)))
        rel_score = np.transpose(scores, (2, 3, 0, 1))

    # Expand from unique shifts back to the full grid with gather_nd:
    #     gathered[i, j] = rel_score[i, inverse_idx[i, j]]
    # TODO(sxjscience) Use advanced indexing once available
    inverse_idx = npx.reshape_like(inverse_idx, rel_positions).astype(np.int32)
    row_idx = npx.arange_like(rel_positions, axis=0).astype(np.int32)
    row_idx = np.expand_dims(row_idx, axis=-1) + np.zeros_like(inverse_idx)
    gathered = npx.gather_nd(rel_score, np.stack([row_idx, inverse_idx]))
    # (L_q, L_m, N, K) -> (N, K, L_q, L_m)
    return np.transpose(gathered, (2, 3, 0, 1))