def relative_position_bucket(relative_position, bidirectional: bool = True,
                             num_buckets: int = 32, max_distance: int = 128):
    """Map the relative position to buckets.

    The implementation is consistent with that in
    [mesh_tensorflow](https://github.com/tensorflow/mesh/blob/c59988047e49b4d2af05603e3170724cdbadc467/mesh_tensorflow/transformer/transformer_layers.py#L595-L637)
    where relative position is defined as `mem_i - query_j`. Thus, a positive value indicates
    that the memory slot is in a later timestamp than the query slot.

    After handling the bidirectional case (see below), the implementation uses the first half
    of buckets to store exact differences and the second half to store the differences after
    a logarithmic transformation.

    Parameters
    ----------
    relative_position
        Shape (...,)
    bidirectional
        Whether we are dealing with bidirectional attention.
        If it's bidirectional, zero and negative shifts are mapped to
        [0, num_buckets // 2), and positive shifts are mapped to
        [num_buckets // 2, num_buckets).
        Otherwise all `num_buckets` buckets are used for zero and negative shifts, and
        every positive shift is treated as a distance of 0.
    num_buckets
        The number of buckets. Must be even when `bidirectional` is True.
    max_distance
        Maximum distance. Positions that fall outside of 'max_distance' will be trimmed,
        i.e. they all land in the last bucket of their half.

    Returns
    -------
    buckets
        Shape (...,). It has the same shape as the `relative_position`.
        The bucket terms are cast to int32, so the result is int32 when
        `relative_position` is int32.

    Raises
    ------
    AssertionError
        If `bidirectional` is True and `num_buckets` is odd.
    """
    ret = 0
    # Flip the sign so that "memory is earlier than the query" becomes a positive distance.
    relative_position = -relative_position
    if bidirectional:
        assert num_buckets % 2 == 0, 'When bidirectional is True, the number of buckets must be ' \
                                     'divisible by 2.'
        num_buckets //= 2
        # Memory slots that are later than the query are offset into the upper half.
        ret = ret + (relative_position < 0).astype(np.int32) * num_buckets
        relative_position = np.abs(relative_position)
    else:
        # Clip all the negative values to 0
        relative_position = np.clip(relative_position, a_min=0, a_max=None)
    # Now, the relative_position is in the range [0, inf)
    # Half of the buckets deal with the exact increments,
    # i.e., 0, 1, 2, ..., max_exact - 1, where max_exact = num_buckets // 2
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # The other half of the buckets are for logarithmically bigger bins in positions up to
    # max_distance.
    # The logarithm is evaluated for every element, including the small ones that
    # `np.where` discards below. Clamp those to `max_exact` first so that we never compute
    # log(0) = -inf and then cast a non-finite value to int32.
    log_input = np.maximum(relative_position, max_exact).astype(np.float32)
    val_if_large = max_exact + (
        np.log(log_input / max_exact)
        / math.log(max_distance / max_exact) * (num_buckets - max_exact)).astype(np.int32)
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    ret = ret + np.where(is_small, relative_position, val_if_large)
    return ret
def get_rmse_log(net, X_train, y_train):
    """Return the root mse between the log of the predictions and the log of the truth.

    The predictions of `net` on `X_train` are clipped to [1, inf) before the logarithm is
    taken. The summed `square_loss` is doubled and divided by the number of rows in
    `X_train` before the square root (the factor of 2 presumably cancels a 1/2 inside
    `square_loss` -- confirm against its definition).
    """
    sample_count = X_train.shape[0]
    # Floor the predictions at 1 so that their logarithm is never negative or undefined.
    safe_preds = np.clip(net(X_train), 1, float('inf'))
    log_pred = np.log(safe_preds)
    log_true = np.log(y_train)
    summed_loss = np.sum(square_loss(log_pred, log_true)).item()
    return np.sqrt(2 * summed_loss / sample_count)
def test_clip():
    """Check `np.clip` forward and backward on an (INT_OVERFLOW, 2) array of ones."""
    expected_shape = (INT_OVERFLOW, 2)
    data = np.ones(expected_shape)
    data.attach_grad()
    with mx.autograd.record():
        # Clipping ones to the degenerate range [1, 1] leaves every value at 1.
        clipped = np.clip(data, 1, 1)
    assert clipped.shape == expected_shape
    assert clipped[0][0] == 1
    clipped.backward()
    grad = data.grad
    assert grad.shape == expected_shape
    assert grad[0][0] == 1
def log_rmse(net, features, labels):
    """Return sqrt(2 * mean(loss)) between log-predictions and log-labels.

    The predictions of `net` on `features` are clipped to [1, inf) before the logarithm
    is taken.
    """
    raw_preds = net(features)
    # To further stabilize the value when the logarithm is taken, every prediction
    # smaller than 1 is replaced by 1.
    floored_preds = np.clip(raw_preds, 1, float('inf'))
    mean_loss = loss(np.log(floored_preds), np.log(labels)).mean()
    return np.sqrt(2 * mean_loss)
def forward(self, rel_positions, query=None):
    """Compute the relative attention scores.

    Parameters
    ----------
    rel_positions
        The relative shifts. Shape (query_length, mem_length).
        Each element represents the shift between the :math:`i-th` element of query and
        the :math:`j-th` element of memory.
    query
        The query for computing the relative scores. The shape depends on the layout.
        It is required for the 'transformer_xl' and 'shaw' methods.
        If we use T5 attention, the query will not be used.

    Returns
    -------
    rel_scores
        The relative attention scores
        Can have shape (batch_size, num_heads, query_length, mem_length)
        or (num_heads, query_length, mem_length)

    Raises
    ------
    AssertionError
        If the method is 'transformer_xl' or 'shaw' and `query` is None.
    NotImplementedError
        If the method or the layout is not one of the supported values.
    """
    method = self._method
    if method == 't5':
        # shape is (K, L_q, L_m)
        return self._rel_pos_embed(rel_positions).transpose((2, 0, 1))
    if method != 'transformer_xl' and method != 'shaw':
        raise NotImplementedError
    assert query is not None, 'Must specify query if method={}'.format(method)

    # Optionally limit the shifts: symmetric range when bidirectional, [0, max] otherwise.
    max_distance = self._max_distance
    if max_distance is not None:
        lower_bound = -max_distance if self._bidirectional else 0
        rel_positions = np.clip(rel_positions, a_min=lower_bound, a_max=max_distance)

    # Embed each distinct shift only once.
    # uniq_rel.shape = (#uniq,), rev_index.shape = (L_q, L_m)
    uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)
    pos_embed = self._rel_pos_embed(uniq_rel)
    if method == 'transformer_xl':
        pos_embed = self._rel_proj(self._dropout_layer(pos_embed))
    # Shape (#uniq, K, C_q)
    pos_embed = npx.reshape(pos_embed, (-2, self._num_heads, self._head_query_units))

    # Calculate the dot-product between query and the relative positional embeddings.
    # After the calculation, rel_score.shape = (L_q, #uniq, N, K)
    layout = self._layout
    if layout not in ('NKT', 'NTK', 'TNK'):
        raise NotImplementedError
    if self._use_einsum:
        # query is (N, K, L_q, C_q) for NKT, (N, L_q, K, C_q) for NTK
        # and (L_q, N, K, C_q) for TNK.
        subscripts = {
            'NKT': 'bnid,jnd->ijbn',
            'NTK': 'bind,jnd->ijbn',
            'TNK': 'ibnd,jnd->ijbn',
        }[layout]
        rel_score = np.einsum(subscripts, query, pos_embed)
    else:
        # Bring the query to (N, K, L_q, C_q) so that one batched matmul serves all layouts.
        if layout == 'NKT':
            nkt_query = query
        elif layout == 'NTK':
            nkt_query = np.swapaxes(query, 1, 2)
        else:
            nkt_query = np.transpose(query, (1, 2, 0, 3))
        per_uniq_score = np.matmul(nkt_query, np.transpose(pos_embed, (1, 2, 0)))
        rel_score = np.transpose(per_uniq_score, (2, 3, 0, 1))

    # We use gather_nd to select the elements
    # TODO(sxjscience) Use advanced indexing once available
    rev_index = npx.reshape_like(rev_index, rel_positions).astype(np.int32)
    row_ids = npx.arange_like(rel_positions, axis=0).astype(np.int32)
    # Broadcast the query index along the memory axis so it pairs with rev_index.
    query_idx = np.expand_dims(row_ids, axis=-1) + np.zeros_like(rev_index)
    gathered = npx.gather_nd(rel_score, np.stack([query_idx, rev_index]))
    return np.transpose(gathered, (2, 3, 0, 1))
def log_rmse(net, features, labels):
    """Return sqrt(2 * mean(loss)) between log-predictions and log-labels.

    The predictions of `net` on `features` are clipped to [1, inf) before the log is taken.
    """
    # To further stabilize the value when the log is taken,
    # set every prediction less than 1 to 1.
    preds_floored = np.clip(net(features), 1, float('inf'))
    log_space_loss = loss(np.log(preds_floored), np.log(labels))
    return np.sqrt(2 * log_space_loss.mean())