Example #1
import math
import numpy as np  # the original source runs on mxnet.numpy; plain NumPy behaves the same here
def relative_position_bucket(relative_position,
                             bidirectional: bool = True,
                             num_buckets: int = 32,
                             max_distance: int = 128):
    """Map the relative position to buckets. The implementation is consistent with that
    in [mesh_tensorflow](https://github.com/tensorflow/mesh/blob/c59988047e49b4d2af05603e3170724cdbadc467/mesh_tensorflow/transformer/transformer_layers.py#L595-L637)
    where relative position is defined as `mem_i - query_j`. Thus, a positive value indicates 
    that the memory slot is at a later timestamp than the query slot.

    After handling the bidirectional case (see below), the implementation uses the first half
    of the buckets to store exact differences and the second half to store differences after
    a logarithmic transformation.

    Parameters
    ----------
    relative_position
        Shape (...,)
    bidirectional
        Whether we are dealing with bidirectional attention.
        If bidirectional, negative shifts are mapped to [0, num_buckets // 2),
        and positive shifts are mapped to [num_buckets // 2, num_buckets).
    num_buckets
        The number of buckets.
    max_distance
        Maximum distance. Positions that fall outside of `max_distance` are clipped.

    Returns
    -------
    buckets
        Shape (...,).
        It has the same shape as the `relative_position`. It will have int32 type.
    """
    ret = 0
    relative_position = -relative_position
    if bidirectional:
        assert num_buckets % 2 == 0, 'When bidirectional is True, the number of buckets must be ' \
                                     'divisible by 2.'
        num_buckets //= 2
        ret = ret + (relative_position < 0).astype(np.int32) * num_buckets
        relative_position = np.abs(relative_position)
    else:
        # Clip all the negative values to 0
        relative_position = np.clip(relative_position, a_min=0, a_max=None)
    # Now, the relative_position is in the range [0, inf)

    # Half of the buckets deal with the exact increments,
    # i.e., 0, 1, 2, ..., max_exact - 1, where max_exact = num_buckets // 2
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions up to
    # max_distance
    val_if_large = max_exact + (
        np.log(relative_position.astype(np.float32) / max_exact) /
        math.log(max_distance / max_exact) *
        (num_buckets - max_exact)).astype(np.int32)
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    ret = ret + np.where(is_small, relative_position, val_if_large)
    return ret
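
For a quick sanity check, here is a minimal sketch using plain NumPy and the defaults (num_buckets=32, max_distance=128): small shifts get exact buckets, larger shifts get logarithmic ones, and the two sign directions occupy separate halves of the bucket range.

buckets = relative_position_bucket(np.array([-3, -1, 1, 2, 17, 100]))
print(buckets)  # [ 3  1 17 18 26 31]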
Example #2
def get_rmse_log(net, X_train, y_train):
    """Gets root mse between the logarithms of the prediction and the truth."""
    num_train = X_train.shape[0]
    clipped_preds = np.clip(net(X_train), 1, float('inf'))
    return np.sqrt(
        2 *
        np.sum(square_loss(np.log(clipped_preds), np.log(y_train))).item() /
        num_train)
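
For reference, assuming square_loss(a, b) = (a - b) ** 2 / 2 (the usual halved squared error, which the leading factor of 2 cancels), the same quantity in plain NumPy is:

import numpy as np

def log_rmse_reference(preds, labels):
    # clip predictions below 1 so the logarithm stays finite
    preds = np.clip(preds, 1, None)
    return np.sqrt(np.mean((np.log(preds) - np.log(labels)) ** 2))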
Example #3
def test_clip():
    # From MXNet's large-tensor test suite, where INT_OVERFLOW is 2**31:
    # verifies that clip and its gradient work on arrays whose size exceeds the int32 range.
    A = np.ones((INT_OVERFLOW, 2))
    A.attach_grad()
    with mx.autograd.record():
        B = np.clip(A, 1, 1)
    assert B.shape == (INT_OVERFLOW, 2)
    assert B[0][0] == 1
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 1
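
The same check at a small scale, as a sketch assuming MXNet's numpy interface (the test above uses INT_OVERFLOW rows precisely to exercise tensors beyond the int32 index range):

import mxnet as mx
from mxnet import np, npx
npx.set_np()

A = np.ones((4, 2))
A.attach_grad()
with mx.autograd.record():
    B = np.clip(A, 1, 1)
B.backward()
# clip passes the gradient through for entries inside [a_min, a_max],
# so every element of A.grad is 1 here
print(A.grad)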
Example #4
def log_rmse(net, features, labels):
    # To further stabilize the value when the logarithm is taken,
    # replace predictions less than 1 with 1
    clipped_preds = np.clip(net(features), 1, float('inf'))
    return np.sqrt(2 * loss(np.log(clipped_preds), np.log(labels)).mean())
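
Assuming `loss` is Gluon's `L2Loss`, which computes half the squared error (so the leading factor of 2 cancels it), this evaluates the log-RMSE

$$\sqrt{\frac{1}{n} \sum_{i=1}^{n} \left(\log \hat{y}_i - \log y_i\right)^2}$$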
Example #5
    def forward(self, rel_positions, query=None):
        """Forward function

        Parameters
        ----------
        rel_positions
            The relative shifts. Shape (query_length, mem_length).
            Each element represents the shift between the :math:`i`-th element of the query and
            the :math:`j`-th element of the memory.
        query
            The query for computing the relative scores. The shape depends on the layout.
            If we use T5 attention, the query will not be used.

        Returns
        -------
        rel_scores
            The relative attention scores.
            Shape (batch_size, num_heads, query_length, mem_length)
            or (num_heads, query_length, mem_length).
        """
        if self._method in ('transformer_xl', 'shaw'):
            assert query is not None, 'Must specify query if method={}'.format(self._method)
            if self._max_distance is not None:
                if self._bidirectional:
                    rel_positions = np.clip(rel_positions,
                                            a_min=-self._max_distance,
                                            a_max=self._max_distance)
                else:
                    rel_positions = np.clip(rel_positions,
                                            a_min=0, a_max=self._max_distance)
            # uniq_rel.shape = (#uniq,), rev_index.shape = (L_q, L_m)
            uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)

            uniq_rel_pos_embed = self._rel_pos_embed(uniq_rel)
            if self._method == 'transformer_xl':
                uniq_rel_pos_embed = self._rel_proj(self._dropout_layer(uniq_rel_pos_embed))
            # Shape (#uniq, K, C_q)
            uniq_rel_pos_embed = npx.reshape(uniq_rel_pos_embed,
                                             (-2, self._num_heads, self._head_query_units))
            # Calculate the dot-product between query and the relative positional embeddings.
            # After the calculation, rel_score.shape = (L_q, #uniq, N, K)
            if self._layout == 'NKT':
                # query_for_rel: (N, K, L_q, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('bnid,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(query,
                                  np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            elif self._layout == 'NTK':
                # query_for_rel: (N, L_q, K, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('bind,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(np.swapaxes(query, 1, 2),
                                  np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            elif self._layout == 'TNK':
                # query_for_rel: (L_q, N, K, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('ibnd,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(np.transpose(query, (1, 2, 0, 3)),
                                  np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            else:
                raise NotImplementedError
            # We use gather_nd to select the elements
            # TODO(sxjscience) Use advanced indexing once available
            rev_index = npx.reshape_like(rev_index, rel_positions).astype(np.int32)
            query_idx = np.expand_dims(npx.arange_like(rel_positions, axis=0).astype(np.int32),
                                       axis=-1) + np.zeros_like(rev_index)
            rel_score = npx.gather_nd(rel_score, np.stack([query_idx, rev_index]))
            rel_score = np.transpose(rel_score, (2, 3, 0, 1))
        elif self._method == 't5':
            # shape is (K, L_q, L_m)
            rel_score = self._rel_pos_embed(rel_positions).transpose((2, 0, 1))
        else:
            raise NotImplementedError
        return rel_score
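
The np.unique(..., return_inverse=True) + gather_nd pattern above embeds each distinct shift only once and then broadcasts the result back to the full (query_length, mem_length) grid. A minimal plain-NumPy sketch of the idea, with a hypothetical random table `embed` standing in for self._rel_pos_embed:

import numpy as np

query_length, mem_length, units = 3, 4, 8
# rel_positions[i, j] = j - i: one shift per (query, memory) pair
rel_positions = np.arange(mem_length)[None, :] - np.arange(query_length)[:, None]

uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)
rev_index = rev_index.reshape(rel_positions.shape)

rng = np.random.default_rng(0)
embed = rng.standard_normal((uniq_rel.size, units))  # one row per distinct shift

# advanced indexing plays the role of npx.gather_nd in the code above
rel_embed = embed[rev_index]
assert rel_embed.shape == (query_length, mem_length, units)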
Example #6
def log_rmse(net, features, labels):
    # To further stabilize the value when the log is taken,
    # replace predictions less than 1 with 1
    net_out = net(features)
    clipped_preds = np.clip(net_out, 1, float('inf'))
    return np.sqrt(2 * loss(np.log(clipped_preds), np.log(labels)).mean())