def get_end_logits(self, contextual_embedding, start_positions, p_mask):
    """
    Parameters
    ----------
    contextual_embedding
        Shape (batch_size, sequence_length, C)
    start_positions
        Shape (batch_size, N)
        We process multiple candidates simultaneously
    p_mask
        Shape (batch_size, sequence_length)

    Returns
    -------
    end_logits
        Shape (batch_size, N, sequence_length)
    """
    # Select the features at the start_positions
    # start_features will have shape (batch_size, N, C)
    start_features = select_vectors_by_position(contextual_embedding, start_positions)
    # Concatenate the start_features and the contextual_embedding
    contextual_embedding = np.expand_dims(contextual_embedding, axis=1)  # (B, 1, T, C)
    start_features = np.expand_dims(start_features, axis=2)  # (B, N, 1, C)
    concat_features = np.concatenate([npx.broadcast_like(start_features,
                                                         contextual_embedding, 2, 2),
                                      npx.broadcast_like(contextual_embedding,
                                                         start_features, 1, 1)],
                                     axis=-1)  # (B, N, T, 2C)
    end_scores = self.end_scores(concat_features)
    end_scores = np.squeeze(end_scores, -1)
    end_logits = masked_logsoftmax(end_scores, mask=np.expand_dims(p_mask, axis=1),
                                   axis=-1)
    return end_logits
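To make the broadcast-and-concat step above concrete, here is a minimal standalone sketch in plain NumPy (the shapes and variable names are made up for illustration, and `select_vectors_by_position` is emulated with `take_along_axis`): it pairs every candidate start feature with every token feature to build the (B, N, T, 2C) tensor that the end-position scorer consumes.

import numpy as onp

B, T, C, N = 2, 6, 4, 3                     # batch, sequence length, channels, candidates
contextual = onp.random.randn(B, T, C)      # stand-in for contextual_embedding
starts = onp.random.randint(0, T, (B, N))   # stand-in for start_positions
# Gather the features at the candidate start positions -> (B, N, C)
start_feats = onp.take_along_axis(contextual, starts[:, :, None], axis=1)
# Broadcast both operands to (B, N, T, C) and concatenate along channels -> (B, N, T, 2C)
pairwise = onp.concatenate(
    [onp.broadcast_to(start_feats[:, :, None, :], (B, N, T, C)),
     onp.broadcast_to(contextual[:, None, :, :], (B, N, T, C))],
    axis=-1)
assert pairwise.shape == (B, N, T, 2 * C)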
def dynamic_masking(self, input_ids, valid_lengths):
    # TODO(zheyuye), two additional flags `disallow_from_mask` and `already_masked`
    # that control the masking status for each position in the sequence.
    """
    Generate masking positions on-the-fly instead of during preprocessing

    Parameters
    ----------
    input_ids
        The batchified input_ids with shape (batch_size, max_seq_length)
    valid_lengths
        The batchified valid_lengths with shape (batch_size, )

    Returns
    -------
    masked_input_ids
        The masked input sequence in which 15% of the tokens are masked with [MASK]
        shape (batch_size, max_seq_length)
    length_masks
        The masking matrix for the whole sequence that indicates which positions
        are greater than valid_length
        shape (batch_size, max_seq_length)
    unmasked_tokens
        The original tokens that appear in the unmasked input sequence
        shape (batch_size, num_masked_positions)
    masked_positions
        The masking positions in mx.np.ndarray
        shape (batch_size, num_masked_positions)
    masked_weights
        The weight matrix containing 0 or 1 to mark the actual effect of masked positions
        shape (batch_size, num_masked_positions)
    """
    N = self._max_num_masked_position
    # Only valid tokens that are not special tokens are allowed to be masked
    valid_candidates = F.np.ones_like(input_ids, dtype=np.bool_)
    ignore_tokens = [self.vocab.cls_id, self.vocab.sep_id, self.vocab.pad_id]

    for ignore_token in ignore_tokens:
        # TODO(zheyuye), Update when the += operator is supported
        valid_candidates = valid_candidates * \
            F.np.not_equal(input_ids, ignore_token)
    valid_lengths = valid_lengths.astype(np.float32)
    valid_candidates = valid_candidates.astype(np.float32)
    num_masked_position = F.np.maximum(
        1, F.np.minimum(N, round(valid_lengths * self._mask_prob)))

    # Get the masking probability of each position
    sample_probs = self._proposal_distribution * valid_candidates
    sample_probs /= F.np.sum(sample_probs, axis=-1, keepdims=True)
    sample_probs = F.npx.stop_gradient(sample_probs)
    gumbels = F.np.random.gumbel(F.np.zeros_like(sample_probs))
    # Follow the official repo and avoid duplicate positions by sampling with the
    # Gumbel top-k trick, see https://github.com/google-research/electra/issues/41
    masked_positions = F.npx.topk(
        F.np.log(sample_probs) + gumbels, k=N,
        axis=-1, ret_typ='indices', dtype=np.int32)

    masked_weights = F.npx.sequence_mask(
        F.np.ones_like(masked_positions),
        sequence_length=num_masked_position,
        use_sequence_length=True, axis=1, value=0)
    masked_positions = masked_positions * masked_weights
    length_masks = F.npx.sequence_mask(
        F.np.ones_like(input_ids, dtype=np.float32),
        sequence_length=valid_lengths,
        use_sequence_length=True, axis=1, value=0)
    unmasked_tokens = select_vectors_by_position(
        input_ids, masked_positions) * masked_weights
    masked_weights = masked_weights.astype(np.float32)
    replaced_positions = (
        F.np.random.uniform(
            F.np.zeros_like(masked_positions),
            F.np.ones_like(masked_positions)) < self._replace_prob) * masked_positions
    # Zero entries in replaced_positions all point at index 0, so fill them with
    # [CLS] to keep the leading [CLS] token from being replaced
    filled = F.np.where(replaced_positions,
                        self.vocab.mask_id,
                        self.vocab.cls_id).astype(np.int32)
    # Mask the chosen tokens by replacing them with [MASK]
    masked_input_ids = update_vectors_by_position(input_ids, filled, replaced_positions)

    # Note: masked_positions may contain multiple zero values when the number of masked
    # positions has not reached the maximum. However, this case hardly occurs since
    # valid_length is almost always equal to max_seq_length.
    masked_input = self.MaskedInput(input_ids=masked_input_ids,
                                    masks=length_masks,
                                    unmasked_tokens=unmasked_tokens,
                                    masked_positions=masked_positions,
                                    masked_weights=masked_weights)
    return masked_input
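The position sampling above relies on the Gumbel top-k trick: adding i.i.d. Gumbel noise to the log-probabilities and taking the top k indices draws k distinct positions, in proportion to sample_probs, which is why duplicates are avoided (the behavior the linked ELECTRA issue discusses). A small illustrative sketch of the trick in plain NumPy, with hypothetical values, is:

import numpy as onp

rng = onp.random.default_rng(0)
probs = onp.array([0.05, 0.4, 0.05, 0.3, 0.2])  # per-position masking probabilities
k = 2                                           # number of positions to mask
# Perturb log-probabilities with Gumbel noise, then take the top-k indices.
# The result is a sample of k distinct positions drawn proportionally to probs.
perturbed = onp.log(probs) + rng.gumbel(size=probs.shape)
sampled_positions = onp.argsort(-perturbed)[:k]
print(sampled_positions)  # k distinct indices, biased toward the high-probability ones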