def inference(self, tokens, token_types, valid_length, p_mask,
              start_top_n: int = 5, end_top_n: int = 5):
    """Get the inference result with beam search

    Parameters
    ----------
    tokens
        The input tokens. Shape (batch_size, sequence_length)
    token_types
        The input token types. Shape (batch_size, sequence_length)
    valid_length
        The valid length of the tokens. Shape (batch_size,)
    p_mask
        The mask which indicates that some tokens won't be used in the calculation.
        Shape (batch_size, sequence_length)
    start_top_n
        The number of candidates to select for the start position.
    end_top_n
        The number of candidates to select for the end position.

    Returns
    -------
    start_top_logits
        The top start logits. Shape (batch_size, start_top_n)
    start_top_index
        Index of the top start logits. Shape (batch_size, start_top_n)
    end_top_logits
        The top end logits. Shape (batch_size, start_top_n, end_top_n)
    end_top_index
        Index of the top end logits. Shape (batch_size, start_top_n, end_top_n)
    answerable_logits
        The answerable logits. Here 0 --> answerable and 1 --> not answerable.
        Shape (batch_size, sequence_length, 2)
    """
    # Shape (batch_size, sequence_length, C)
    if self.use_segmentation:
        contextual_embeddings = self.backbone(tokens, token_types, valid_length)
    else:
        contextual_embeddings = self.backbone(tokens, valid_length)
    start_logits = self.get_start_logits(contextual_embeddings, p_mask)
    # The shape of start_top_index will be (..., start_top_n)
    start_top_logits, start_top_index = npx.topk(start_logits, k=start_top_n,
                                                 axis=-1, ret_typ='both')
    end_logits = self.get_end_logits(contextual_embeddings, start_top_index, p_mask)
    # Note that end_top_index and end_top_log_probs have shape (bsz, start_n_top, end_n_top),
    # so that for each start position there are end_n_top end positions on the third dim.
    end_top_logits, end_top_index = npx.topk(end_logits, k=end_top_n,
                                             axis=-1, ret_typ='both')
    answerable_logits = self.get_answerable_logits(contextual_embeddings, p_mask)
    return start_top_logits, start_top_index, end_top_logits, end_top_index, \
        answerable_logits
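# A minimal, self-contained sketch (toy logits, not part of the model above) showing how the
# npx.topk call with ret_typ='both' used in `inference` returns both the top values and their
# positions; shapes and numbers here are illustrative only.
from mxnet import np, npx
npx.set_np()

start_logits = np.array([[0.1, 2.0, -1.0, 0.5],
                         [1.5, 0.0, 0.2, 3.0]])   # (batch_size=2, seq_len=4)
top_logits, top_index = npx.topk(start_logits, k=2, axis=-1, ret_typ='both')
print(top_logits.shape, top_index.shape)          # (2, 2) (2, 2)
print(top_index)                                  # indices of the two largest logits per row: [[1, 3], [3, 0]]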
def forward(self, scores, target_dists, finished, best_hyp_indices):
    """
    Choose an extension of each hypothesis from its softmax distribution.

    :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
    :param target_dists: The non-cumulative target distributions (ignored).
    :param finished: The list of finished hypotheses.
    :param best_hyp_indices: Best hypothesis indices constant.
    :return: The row indices, column indices, and values of the sampled words.
    """
    # Map the negative logprobs to probabilities so as to have a distribution
    target_dists = np.exp(-target_dists)

    # n == 0 means sample from the full vocabulary. Otherwise, we sample from the top n.
    if self.n != 0:
        # select the top n in each row, via a mask
        masked_items = npx.topk(target_dists, k=self.n, ret_typ='mask', axis=1, is_ascend=False)
        # set unmasked items to 0
        masked_items = np.where(masked_items, target_dists, masked_items)
        # renormalize
        target_dists = masked_items / np.sum(masked_items, axis=1, keepdims=True)

    # Sample from the target distributions over words, then get the corresponding values from the cumulative scores
    best_word_indices = npx.random.categorical(target_dists, get_prob=False)
    # Zeroes for finished hypotheses.
    best_word_indices = np.where(finished, np.zeros_like(best_word_indices), best_word_indices)
    values = npx.pick(scores, best_word_indices, axis=1, keepdims=True)

    best_hyp_indices = npx.slice_like(best_hyp_indices, best_word_indices, axes=(0,))

    return best_hyp_indices, best_word_indices, values
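# Standalone sketch (assumed toy probabilities, not the real decoder state) of the top-n
# truncation and renormalization step above: a top-k mask keeps the n largest probabilities
# per row, everything else is zeroed out, and the result is renormalized before sampling.
# Multiplying by the mask is used here in place of the np.where call above; the effect is the same.
from mxnet import np, npx
npx.set_np()

n = 2
target_dists = np.array([[0.1, 0.6, 0.2, 0.1],
                         [0.4, 0.1, 0.1, 0.4]])   # (rows=2, vocab=4)
mask = npx.topk(target_dists, k=n, ret_typ='mask', axis=1, is_ascend=False)
truncated = target_dists * mask
truncated = truncated / np.sum(truncated, axis=1, keepdims=True)
print(truncated)   # row 0 -> [0., 0.75, 0.25, 0.], row 1 -> [0.5, 0., 0., 0.5]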
def test_topk():
    A = np.ones((2, INT_OVERFLOW))
    A[0][100] = 2
    A[1][200] = 2
    A.attach_grad()
    with mx.autograd.record():
        B = npx.topk(A, k=2)
    assert B.shape == (2, 2)
    assert B[0][0] == 100 and B[1][0] == 200
    B.backward()
    assert A.grad.shape == (2, INT_OVERFLOW)
    assert A.grad[0][0] == 0
def forward(self, scores: np.ndarray,
            vocab_slice_ids: Optional[np.ndarray] = None,
            target_factors: Optional[np.ndarray] = None) -> np.ndarray:
    # shape: (batch*beam=1, 1)
    # argmin has trouble with fp16 inputs on GPUs, using top1 instead
    best_word_index = npx.topk(scores, axis=-1, k=1, ret_typ='indices',
                               is_ascend=True, dtype='int32')
    # Map from restricted to full vocab ids if needed
    if vocab_slice_ids is not None:
        best_word_index = np.take(vocab_slice_ids, best_word_index, axis=0)

    if target_factors is not None:
        best_word_index = np.concatenate((best_word_index, target_factors), axis=1)

    return best_word_index
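# Quick illustrative check (toy fp32 scores on CPU, assumed inputs) of the workaround above:
# topk with k=1 and is_ascend=True picks the index of the smallest score, i.e. it behaves
# like argmin while also being usable with fp16 inputs on GPUs.
from mxnet import np, npx
npx.set_np()

scores = np.array([[3.0, 0.5, 2.0]])   # (batch*beam=1, vocab=3)
best = npx.topk(scores, axis=-1, k=1, ret_typ='indices', is_ascend=True, dtype='int32')
print(best)                            # [[1]], the position of the minimum score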
def dynamic_masking(self, input_ids, valid_lengths):
    # TODO(zheyuye), two additional flags `disallow_from_mask` and `already_masked`
    # that control the masking status for each position in the sequence.
    """
    Generate masking positions on-the-fly instead of during preprocessing

    Parameters
    ----------
    input_ids
        The batchified input_ids with shape (batch_size, max_seq_length)
    valid_lengths
        The batchified valid_lengths with shape (batch_size,)

    Returns
    -------
    masked_input_ids
        The masked input sequence in which 15% of the tokens are masked with [MASK]
        shape (batch_size, max_seq_length)
    length_masks
        The masking matrix for the whole sequence that indicates which positions
        are greater than valid_length.
        shape (batch_size, max_seq_length)
    unmasked_tokens
        The original tokens that appear in the unmasked input sequence
        shape (batch_size, num_masked_positions)
    masked_positions
        The masking positions in mx.np.ndarray
        shape (batch_size, num_masked_positions)
    masked_lm_weights
        The weight matrix containing 0 or 1 to mark the actual effect of masked positions
        shape (batch_size, num_masked_positions)
    """
    N = self._max_num_masked_position
    # Only valid tokens that are not special tokens are allowed to be masked
    valid_candidates = np.ones_like(input_ids, dtype=np.bool)
    ignore_tokens = [self.vocab.cls_id, self.vocab.sep_id, self.vocab.pad_id]

    for ignore_token in ignore_tokens:
        # TODO(zheyuye), update when the += operation is supported
        valid_candidates = valid_candidates * \
            np.not_equal(input_ids, ignore_token)
    valid_lengths = valid_lengths.astype(np.float32)
    valid_candidates = valid_candidates.astype(np.float32)
    num_masked_position = mxnp.maximum(
        1, np.minimum(N, round(valid_lengths * self._mask_prob)))

    # Get the masking probability of each position
    sample_probs = self._proposal_distribution * valid_candidates
    sample_probs /= mxnp.sum(sample_probs, axis=-1, keepdims=True)
    sample_probs = npx.stop_gradient(sample_probs)
    gumbels = mxnp.random.gumbel(np.zeros_like(sample_probs))
    # Following the official repo, avoid duplicate positions by adding Gumbel noise and
    # taking the top-k, see https://github.com/google-research/electra/issues/41
    masked_positions = npx.topk(
        mxnp.log(sample_probs) + gumbels, k=N,
        axis=-1, ret_typ='indices', dtype=np.int32)

    masked_weights = npx.sequence_mask(
        mxnp.ones_like(masked_positions),
        sequence_length=num_masked_position,
        use_sequence_length=True, axis=1, value=0)
    masked_positions = masked_positions * masked_weights
    length_masks = npx.sequence_mask(
        mxnp.ones_like(input_ids, dtype=np.float32),
        sequence_length=valid_lengths,
        use_sequence_length=True, axis=1, value=0)
    unmasked_tokens = select_vectors_by_position(
        input_ids, masked_positions) * masked_weights
    masked_weights = masked_weights.astype(np.float32)
    replaced_positions = (
        mxnp.random.uniform(
            mxnp.zeros_like(masked_positions),
            mxnp.ones_like(masked_positions)) < self._replace_prob) * masked_positions
    # Deal with multiple zero values in replaced_positions, which would cause
    # the [CLS] token to be replaced
    filled = mxnp.where(replaced_positions,
                        self.vocab.mask_id,
                        self.vocab.cls_id).astype(np.int32)
    # Mask tokens by replacing them with [MASK]
    masked_input_ids = update_vectors_by_position(input_ids, filled, replaced_positions)

    # Note: masked_positions is likely to contain multiple zero values if the number of masked
    # positions has not reached the maximum. However, this case hardly occurs since valid_length
    # is almost always equal to max_seq_length.
    masked_input = self.MaskedInput(input_ids=masked_input_ids,
                                    masks=length_masks,
                                    unmasked_tokens=unmasked_tokens,
                                    masked_positions=masked_positions,
                                    masked_weights=masked_weights)
    return masked_input
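# Standalone sketch (toy probabilities, Gumbel noise derived by hand from uniform samples,
# not the ELECTRA code path above) of the Gumbel-top-k trick used in `dynamic_masking`:
# adding Gumbel noise to the log-probabilities and taking the top-k indices draws k
# distinct positions from the proposal distribution, i.e. sampling without replacement.
from mxnet import np, npx
npx.set_np()

sample_probs = np.array([[0.05, 0.4, 0.3, 0.2, 0.05]])   # (batch=1, seq_len=5)
u = np.random.uniform(low=1e-9, high=1.0, size=sample_probs.shape)
gumbels = -np.log(-np.log(u))                            # standard Gumbel(0, 1) noise
masked_positions = npx.topk(np.log(sample_probs) + gumbels, k=3,
                            axis=-1, ret_typ='indices', dtype='int32')
print(masked_positions.shape)                            # (1, 3): three distinct positions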
def forward(self, scores):
    values, indices = npx.topk(scores, axis=1, k=self.k, ret_typ='both',
                               is_ascend=True, dtype='int32')
    # Project indices back into original shape (which is different for t==1 and t>1)
    values, indices = np.reshape(values, (-1, 1)), np.reshape(indices, (-1,))
    return indices, values
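# Toy illustration (assumed two rows of scores, not the real decoder state) of the reshape
# above: the (rows, k) top-k outputs are flattened so that indices become a single vector of
# candidate ids and values become a column of scores.
from mxnet import np, npx
npx.set_np()

k = 2
scores = np.array([[1.0, 0.2, 0.7],
                   [0.9, 0.1, 0.4]])   # (rows=2, vocab=3)
values, indices = npx.topk(scores, axis=1, k=k, ret_typ='both', is_ascend=True, dtype='int32')
values, indices = np.reshape(values, (-1, 1)), np.reshape(indices, (-1,))
print(indices.shape, values.shape)     # (4,) (4, 1)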