def forward(self, data, steps=None):  # pylint: disable=arguments-differ
    """
    Applies positional embeddings to input data.

    :param data: Input data. Shape: (batch, length or 1, num_embed)
    :param steps: Optional steps input. If given, shape is (batch_size or 1, seq_len,).
                  If omitted (None), the leading rows of the embedding weight are used,
                  sliced to the size of ``data`` along axis 1.
    :return: Data with positional embeddings added. If ``self.scale_up_input`` is set,
             ``data`` is first multiplied by sqrt(num_embed).
    """
    weight = self.weight.data()
    if steps is None:
        # (1, length, num_embed): slice the weight along axis 1 to match data's axis-1 size.
        pos_embedding = npx.slice_like(np.expand_dims(weight, axis=0), data, axes=(1, ))
    else:
        # (batch_size or 1, seq_len, num_embed): look up one weight row per step index.
        pos_embedding = npx.embedding(steps, weight, self.max_seq_len, self.num_embed)
    if self.weight_type == 'fixed':
        # 'fixed' embeddings must not receive gradients.
        pos_embedding = npx.stop_gradient(pos_embedding)
    if self.scale_up_input:
        data = data * (self.num_embed ** 0.5)
    return data + pos_embedding
def gumbel_softmax(logits, temperature: float = 1.0, eps: float = 1E-10,
                   hard=True, use_np_gumbel: bool = True):
    r"""Perform the gumbel-softmax trick to generate differentiable one-hot vectors from the
    input logits.

    Here, the gumbel distribution is

        Gumbel(\alpha) = -log (-log U) + \log \alpha,

    in which U is the uniform(0, 1) distribution.

    A nice property of Gumbel is:

        \argmax({Gumbel(\alpha_i)}) \sim multinomial(\alpha_i)

    The Gumbel-Softmax trick is to use the softmax + straight-through estimator to produce
    one-hot vectors that represent the sampling result.

    References:

        1. https://en.wikipedia.org/wiki/Gumbel_distribution
        2. [ICLR2017] Categorical Reparameterization with Gumbel-Softmax

    Parameters
    ----------
    logits
        Logits. Shape (..., V)
    temperature
        The temperature of the softmax applied to the Gumbel-perturbed logits. Smaller
        values give outputs closer to one-hot; larger values give smoother outputs.
    eps
        The eps for stability of the log computations. Only used when
        ``use_np_gumbel`` is False.
    hard
        Whether to use the straight-through estimator to produce one-hot vectors.
    use_np_gumbel
        Whether to use the random.gumbel operator to draw the Gumbel noise. If False,
        the noise is computed as -log(-log(U + eps) + eps) with U ~ uniform(0, 1).

    Returns
    -------
    ret
        The returned output. Shape (..., V)
    """
    # TODO(sxjscience) Investigate the impact of random.gumbel:
    #  random.gumbel has no eps and may have problems when calculating the gradient.
    if use_np_gumbel:
        gumbels = np.random.gumbel(np.zeros_like(logits))
    else:
        u = np.random.uniform(np.zeros_like(logits), 1)
        gumbels = -np.log(-np.log(u + eps) + eps)
    y = npx.softmax((gumbels + logits) / temperature, axis=-1)
    if not hard:
        return y
    # One-hot mask of the argmax. It is cast to y's dtype so the subtraction below is a
    # well-defined float op rather than relying on bool/float type promotion.
    y_hard = (np.max(y, axis=-1, keepdims=True) == y).astype(y.dtype)
    # Straight-through estimator: the forward value is y_hard, the gradient is that of y.
    return npx.stop_gradient(y_hard - y) + y
def forward(self, x: np.ndarray) -> np.ndarray:
    """Build an additive causal (autoregressive) bias from the length of ``x``'s axis 1.

    Entry [0, i, j] is ``-C.LARGE_VALUES[self._dtype]`` when j > i and 0 otherwise. The
    result is wrapped in ``stop_gradient``.

    :param x: Input whose axis-1 size is used as the sequence length.
    :return: Bias of shape (1, length, length).
    """
    # Shape: (length,) -- indices 0 .. length - 1 along x's axis 1.
    positions = npx.arange_like(x, axis=1)
    # Key index j varies along the columns, query index i along the rows.
    key_idx = np.expand_dims(positions, axis=0)    # (1, length)
    query_idx = np.expand_dims(positions, axis=1)  # (length, 1)
    # (length, length): 1 in the strict upper triangle (j > i), 0 on/below the diagonal.
    future = npx.broadcast_greater(key_idx, query_idx)
    # Turn the 0/1 matrix into the additive bias and prepend a leading axis of size 1.
    causal_bias = np.expand_dims(future * -C.LARGE_VALUES[self._dtype], axis=0)
    return npx.stop_gradient(causal_bias)
def test_stop_gradient():
    """stop_gradient on an (INT_OVERFLOW, 2) array keeps values but blocks the gradient."""
    expected_shape = (INT_OVERFLOW, 2)
    inp = np.ones(expected_shape)
    inp.attach_grad()
    with mx.autograd.record():
        out = npx.stop_gradient(inp * 3)
    assert out.shape == expected_shape
    assert out[0, 0] == 3
    out.backward()
    # Without stop_gradient() the gradient here would be 3.
    assert inp.grad[0, 0] == 0
def dynamic_masking(self, input_ids, valid_lengths):
    # TODO(zheyuye), two additional flag `disallow_from_mask` and `already_masked`
    # that control the masking status for each positions in the sequence.
    """
    Generate masking positions on-the-fly instead of during preprocessing.

    For each sequence, N = ``self._max_num_masked_position`` candidate positions are
    drawn with the Gumbel top-k trick (sampling without replacement). Draws are weighted
    by ``self._proposal_distribution`` and restricted to tokens that are not [CLS],
    [SEP] or [PAD].

    Only the first ``max(1, min(N, round(valid_length * self._mask_prob)))`` draws of
    each row are kept. Each kept position is then replaced with ``self.vocab.mask_id``
    with probability ``self._replace_prob``.

    Parameters
    ----------
    input_ids
        The batchified input_ids with shape (batch_size, max_seq_length)
    valid_lengths
        The batchified valid_lengths with shape (batch_size, )

    Returns
    -------
    masked_input
        A ``self.MaskedInput`` record with the following fields:

        input_ids
            The input sequence after the replaced positions are overwritten with
            [MASK]. shape (batch_size, max_seq_length)
        masks
            Float32 length mask built by ``npx.sequence_mask`` from ``valid_lengths``.
            It is 1 before valid_length and 0 from valid_length onwards.
            shape (batch_size, max_seq_length)
        unmasked_tokens
            The original token ids at ``masked_positions``, zeroed where
            ``masked_weights`` is 0. shape (batch_size, N)
        masked_positions
            Int32 sampled positions. Entries beyond each row's kept count are set
            to 0. shape (batch_size, N)
        masked_weights
            Float32 weights: 1 for the kept entries of each row, 0 for the padding
            entries. shape (batch_size, N)
    """
    N = self._max_num_masked_position
    # Only regular tokens may be masked: exclude [CLS], [SEP] and [PAD] positions.
    valid_candidates = np.ones_like(input_ids, dtype=np.bool)
    ignore_tokens = [
        self.vocab.cls_id, self.vocab.sep_id, self.vocab.pad_id
    ]
    for ignore_token in ignore_tokens:
        # TODO(zheyuye), Update when operation += supported
        valid_candidates = valid_candidates * \
            np.not_equal(input_ids, ignore_token)
    valid_lengths = valid_lengths.astype(np.float32)
    valid_candidates = valid_candidates.astype(np.float32)
    # Per-row number of positions to keep: round(valid_length * mask_prob) clipped to [1, N].
    num_masked_position = mxnp.maximum(
        1, np.minimum(N, round(valid_lengths * self._mask_prob)))

    # Get the sampling probability of each position (0 for excluded tokens), normalized per row.
    # NOTE(review): a row with no valid candidates sums to 0 here and would divide by zero.
    sample_probs = self._proposal_distribution * valid_candidates
    sample_probs /= mxnp.sum(sample_probs, axis=-1, keepdims=True)
    sample_probs = npx.stop_gradient(sample_probs)
    gumbels = mxnp.random.gumbel(np.zeros_like(sample_probs))
    # Gumbel top-k: adding Gumbel noise to log-probabilities and taking the top N indices
    # samples N distinct positions, following the official ELECTRA repo:
    # https://github.com/google-research/electra/issues/41
    # Excluded positions have log(0) = -inf and therefore rank last.
    masked_positions = npx.topk(mxnp.log(sample_probs) + gumbels, k=N,
                                axis=-1, ret_typ='indices', dtype=np.int32)
    # 1 for the first num_masked_position entries of each row, 0 afterwards.
    masked_weights = npx.sequence_mask(mxnp.ones_like(masked_positions),
                                       sequence_length=num_masked_position,
                                       use_sequence_length=True, axis=1, value=0)
    # Padding entries become position 0.
    masked_positions = masked_positions * masked_weights
    length_masks = npx.sequence_mask(mxnp.ones_like(input_ids, dtype=np.float32),
                                     sequence_length=valid_lengths,
                                     use_sequence_length=True, axis=1, value=0)
    # select_vectors_by_position is assumed to gather input_ids at masked_positions
    # per row -- confirm against its definition.
    unmasked_tokens = select_vectors_by_position(
        input_ids, masked_positions) * masked_weights
    masked_weights = masked_weights.astype(np.float32)
    # Each kept position is chosen for replacement with probability self._replace_prob.
    # Unchosen and padding entries become 0.
    replaced_positions = (mxnp.random.uniform(
        mxnp.zeros_like(masked_positions),
        mxnp.ones_like(masked_positions)) < self._replace_prob) * masked_positions
    # replaced_positions may contain many zeros, each of which writes into position 0.
    # Filling those entries with cls_id (rather than mask_id) keeps position 0 as [CLS].
    # NOTE(review): this assumes input_ids[:, 0] is always the [CLS] token.
    filled = mxnp.where(replaced_positions, self.vocab.mask_id, self.vocab.cls_id).astype(np.int32)
    # Masking token by replacing with [MASK]; update_vectors_by_position is assumed to
    # scatter `filled` into input_ids at replaced_positions -- confirm against its definition.
    masked_input_ids = update_vectors_by_position(input_ids, filled, replaced_positions)

    # Note: masked_positions may contain multiple zero values when a row's number of
    # masked positions is below the maximum N. This rarely happens in practice because
    # valid_length is almost always equal to max_seq_length.
    masked_input = self.MaskedInput(input_ids=masked_input_ids,
                                    masks=length_masks,
                                    unmasked_tokens=unmasked_tokens,
                                    masked_positions=masked_positions,
                                    masked_weights=masked_weights)
    return masked_input