def test_reshape_like():
    """Check ``npx.reshape_like`` forward and backward shapes and values.

    An all-ones array of shape ``(INT_OVERFLOW, 2)`` is reshaped to match a
    ``(2, INT_OVERFLOW)`` template; the result and the gradient flowing back
    to the input are then spot-checked at index ``[0][0]``.

    NOTE(review): ``INT_OVERFLOW`` is defined outside this block; presumably
    a dimension size chosen to stress large-tensor indexing -- confirm.
    """
    src_shape = (INT_OVERFLOW, 2)
    dst_shape = (2, INT_OVERFLOW)

    source = np.ones(src_shape)
    source.attach_grad()
    with mx.autograd.record():
        # Only the template's shape matters to reshape_like.
        template = np.zeros(dst_shape)
        reshaped = npx.reshape_like(source, template)

    # Forward: the output takes the template's shape and keeps the ones.
    assert reshaped.shape == dst_shape
    assert reshaped[0][0] == 1

    # Backward: the gradient comes back in the input's original shape.
    reshaped.backward()
    assert source.grad.shape == src_shape
    assert source.grad[0][0] == 1
Esempio n. 2
0
    def forward(self, rel_positions, query=None):
        """Compute the relative-position attention scores.

        Parameters
        ----------
        rel_positions
            The relative shifts. Shape (query_length, mem_length).
            Entry :math:`(i, j)` is the shift between the :math:`i`-th element of the
            query and the :math:`j`-th element of the memory.
        query
            The query used to score against the relative positional embeddings.
            Its shape depends on the layout. It is required when the method is
            'transformer_xl' or 'shaw' and is not used when the method is 't5'.

        Returns
        -------
        rel_scores
            The relative attention scores
            Can have shape (batch_size, num_heads, query_length, mem_length)
            or (num_heads, query_length, mem_length)

        Raises
        ------
        NotImplementedError
            If the configured method or layout is not one of the supported values.
        """
        method = self._method
        if method == 't5':
            # Embedding lookup on the shifts, then move the head axis first:
            # shape is (K, L_q, L_m)
            return self._rel_pos_embed(rel_positions).transpose((2, 0, 1))
        if method not in ('transformer_xl', 'shaw'):
            raise NotImplementedError

        assert query is not None, 'Must specify query if method={}'.format(method)

        # Clip the shifts only when a maximum distance is configured. The lower
        # bound is symmetric in the bidirectional case and zero otherwise.
        if self._max_distance is not None:
            lower_bound = -self._max_distance if self._bidirectional else 0
            rel_positions = np.clip(rel_positions,
                                    a_min=lower_bound, a_max=self._max_distance)

        # Embed each distinct shift once; `inverse_idx` maps every position of
        # `rel_positions` back to its entry in `distinct_shifts`.
        # distinct_shifts.shape = (#uniq,), inverse_idx.shape = (L_q, L_m)
        distinct_shifts, inverse_idx = np.unique(rel_positions, return_inverse=True)

        shift_embed = self._rel_pos_embed(distinct_shifts)
        if method == 'transformer_xl':
            shift_embed = self._rel_proj(self._dropout_layer(shift_embed))
        # Shape (#uniq, K, C_q)
        shift_embed = npx.reshape(shift_embed,
                                  (-2, self._num_heads, self._head_query_units))

        # Dot-product between the query and the relative positional embeddings.
        # Whatever the layout, the result has shape (L_q, #uniq, N, K).
        einsum_equations = {
            'NKT': 'bnid,jnd->ijbn',  # query: (N, K, L_q, C_q)
            'NTK': 'bind,jnd->ijbn',  # query: (N, L_q, K, C_q)
            'TNK': 'ibnd,jnd->ijbn',  # query: (L_q, N, K, C_q)
        }
        layout = self._layout
        if layout not in einsum_equations:
            raise NotImplementedError
        if self._use_einsum:
            rel_score = np.einsum(einsum_equations[layout], query, shift_embed)
        else:
            # Bring the query to (N, K, L_q, C_q) so one matmul serves all layouts.
            if layout == 'NKT':
                query_nkt = query
            elif layout == 'NTK':
                query_nkt = np.swapaxes(query, 1, 2)
            else:
                query_nkt = np.transpose(query, (1, 2, 0, 3))
            # (N, K, L_q, C_q) x (K, C_q, #uniq) -> (N, K, L_q, #uniq)
            scores_nkt = np.matmul(query_nkt, np.transpose(shift_embed, (1, 2, 0)))
            rel_score = np.transpose(scores_nkt, (2, 3, 0, 1))

        # We use gather_nd to select the elements
        # TODO(sxjscience) Use advanced indexing once available
        inverse_idx = npx.reshape_like(inverse_idx, rel_positions).astype(np.int32)
        query_rows = npx.arange_like(rel_positions, axis=0).astype(np.int32)
        # Broadcast the row indices across the memory axis so they pair with inverse_idx.
        query_rows = np.expand_dims(query_rows, axis=-1) + np.zeros_like(inverse_idx)
        selected = npx.gather_nd(rel_score, np.stack([query_rows, inverse_idx]))
        # (L_q, L_m, N, K) -> (N, K, L_q, L_m)
        return np.transpose(selected, (2, 3, 0, 1))