Example #1
    def inference(self, tokens, token_types, valid_length, p_mask,
                  start_top_n: int = 5, end_top_n: int = 5):
        """Get the inference result with beam search

        Parameters
        ----------
        tokens
            The input tokens. Shape (batch_size, sequence_length)
        token_types
            The input token types. Shape (batch_size, sequence_length)
        valid_length
            The valid length of the tokens. Shape (batch_size,)
        p_mask
            The mask which indicates that some tokens won't be used in the calculation.
            Shape (batch_size, sequence_length)
        start_top_n
            The number of candidates to select for the start position.
        end_top_n
            The number of candidates to select for the end position.

        Returns
        -------
        start_top_logits
            The top start logits
            Shape (batch_size, start_top_n)
        start_top_index
            Index of the top start logits
            Shape (batch_size, start_top_n)
        end_top_logits
            The top end logits.
            Shape (batch_size, start_top_n, end_top_n)
        end_top_index
            Index of the top end logits
            Shape (batch_size, start_top_n, end_top_n)
        answerable_logits
            The answerable logits. Here 0 --> answerable and 1 --> not answerable.
            Shape (batch_size, sequence_length, 2)
        """
        # Shape (batch_size, sequence_length, C)
        if self.use_segmentation:
            contextual_embeddings = self.backbone(tokens, token_types, valid_length)
        else:
            contextual_embeddings = self.backbone(tokens, valid_length)
        start_logits = self.get_start_logits(contextual_embeddings, p_mask)
        # The shape of start_top_index will be (..., start_top_n)
        start_top_logits, start_top_index = npx.topk(start_logits, k=start_top_n, axis=-1,
                                                     ret_typ='both')
        end_logits = self.get_end_logits(contextual_embeddings, start_top_index, p_mask)
        # Note that end_top_logits and end_top_index have shape (batch_size, start_top_n, end_top_n),
        # so for each start position there are end_top_n end positions on the third dim.
        end_top_logits, end_top_index = npx.topk(end_logits, k=end_top_n, axis=-1,
                                                 ret_typ='both')
        answerable_logits = self.get_answerable_logits(contextual_embeddings, p_mask)
        return start_top_logits, start_top_index, end_top_logits, end_top_index, \
               answerable_logits
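A minimal standalone sketch of the ret_typ='both' call used above (the toy logits and the import `from mxnet import np, npx` are assumptions for illustration): it returns the top-k values together with their indices along the last axis.

from mxnet import np, npx

# Hypothetical start logits for a batch of 2 sequences of length 6.
start_logits = np.array([[0.1, 2.0, 0.3, 1.5, -0.2, 0.9],
                         [1.1, 0.0, 0.7, 2.3, 0.4, -1.0]])

# ret_typ='both' returns (values, indices); both have shape (batch_size, k).
top_logits, top_index = npx.topk(start_logits, k=3, axis=-1, ret_typ='both')
print(top_logits.shape, top_index.shape)  # (2, 3) (2, 3)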
Example #2
    def forward(self, scores, target_dists, finished, best_hyp_indices):
        """
        Choose an extension of each hypothesis from its softmax distribution.

        :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
        :param target_dists: The non-cumulative target distributions (ignored).
        :param finished: The list of finished hypotheses.
        :param best_hyp_indices: Best hypothesis indices constant.
        :return: The row indices, column indices, and values of the sampled words.
        """
        # Map the negative logprobs to probabilities so as to have a distribution
        target_dists = np.exp(-target_dists)

        # n == 0 means sample from the full vocabulary. Otherwise, we sample from the top n.
        if self.n != 0:
            # select the top n in each row, via a mask
            masked_items = npx.topk(target_dists, k=self.n, ret_typ='mask', axis=1, is_ascend=False)
            # set unmasked items to 0
            masked_items = np.where(masked_items, target_dists, masked_items)
            # renormalize
            target_dists = masked_items / np.sum(masked_items, axis=1, keepdims=True)

        # Sample from the target distributions over words, then get the corresponding values from the cumulative scores
        best_word_indices = npx.random.categorical(target_dists, get_prob=False)
        # Zeroes for finished hypotheses.
        best_word_indices = np.where(finished, np.zeros_like(best_word_indices), best_word_indices)
        values = npx.pick(scores, best_word_indices, axis=1, keepdims=True)

        best_hyp_indices = npx.slice_like(best_hyp_indices, best_word_indices, axes=(0,))

        return best_hyp_indices, best_word_indices, values
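A small sketch of the top-n masking step above (the toy distribution and the import `from mxnet import np, npx` are assumptions): ret_typ='mask' marks the n largest entries per row, everything else is zeroed out, and each row is renormalized before sampling.

from mxnet import np, npx

# Hypothetical probabilities over a 5-word vocabulary for 2 hypotheses.
dists = np.array([[0.05, 0.40, 0.10, 0.30, 0.15],
                  [0.25, 0.05, 0.45, 0.05, 0.20]])

n = 2
# 1 for the top-n entries in each row, 0 elsewhere.
mask = npx.topk(dists, k=n, ret_typ='mask', axis=1, is_ascend=False)
# Zero out everything outside the top n, then renormalize each row.
kept = np.where(mask, dists, np.zeros_like(dists))
kept = kept / np.sum(kept, axis=1, keepdims=True)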
Example #3
def test_topk():
    A = np.ones((2, INT_OVERFLOW))
    A[0][100] = 2
    A[1][200] = 2
    A.attach_grad()
    with mx.autograd.record():
        B = npx.topk(A, k=2)
    assert B.shape == (2, 2)
    assert B[0][0] == 100 and B[1][0] == 200
    B.backward()
    assert A.grad.shape == (2, INT_OVERFLOW)
    assert A.grad[0][0] == 0
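For reference, a scaled-down version of the same check (the small array size and the import `from mxnet import np, npx` are assumptions): with the default ret_typ='indices', npx.topk returns the positions of the largest entries per row.

from mxnet import np, npx

A = np.ones((2, 1000))
A[0][100] = 2
A[1][200] = 2
B = npx.topk(A, k=2)  # default ret_typ='indices'
assert B.shape == (2, 2)
assert B[0][0] == 100 and B[1][0] == 200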
Example #4
    def forward(self,
                scores: np.ndarray,
                vocab_slice_ids: Optional[np.ndarray] = None,
                target_factors: Optional[np.ndarray] = None) -> np.ndarray:
        # shape: (batch*beam=1, 1)
        # argmin has trouble with fp16 inputs on GPUs, using top1 instead
        best_word_index = npx.topk(scores, axis=-1, k=1, ret_typ='indices', is_ascend=True, dtype='int32')
        # Map from restricted to full vocab ids if needed
        if vocab_slice_ids is not None:
            best_word_index = np.take(vocab_slice_ids, best_word_index, axis=0)
        if target_factors is not None:
            best_word_index = np.concatenate((best_word_index, target_factors), axis=1)
        return best_word_index
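A toy run of the trick in this example (the score values and the import `from mxnet import np, npx` are assumptions): the scores are negative log-probabilities, so the best word is the minimum, and an ascending top-1 picks it without calling argmin.

from mxnet import np, npx

# Hypothetical negative log-probabilities for one hypothesis over 4 words.
scores = np.array([[2.3, 0.7, 1.9, 3.1]])

# Ascending top-1 returns the index of the smallest score, i.e. the best word.
best_word_index = npx.topk(scores, axis=-1, k=1, ret_typ='indices',
                           is_ascend=True, dtype='int32')
print(best_word_index)  # [[1]]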
Example #5
    def dynamic_masking(self, input_ids, valid_lengths):
        # TODO(zheyuye), add two additional flags `disallow_from_mask` and `already_masked`
        # that control the masking status for each position in the sequence.
        """
        Generate masking positions on-the-fly instead of during preprocessing
        Parameters
        ----------
        input_ids
            The batchified input_ids with shape (batch_size, max_seq_length)
        valid_lengths
            The batchified valid_lengths with shape (batch_size, )
        Returns
        -------
        masked_input_ids
            The masked input sequence, in which 15% of the tokens are replaced with [MASK]
            shape (batch_size, max_seq_length)
        length_masks
            The mask for the whole sequence, with 0 at positions that are
            greater than or equal to valid_length and 1 elsewhere.

            shape (batch_size, max_seq_length)
        unmasked_tokens
            The original tokens that appear in the unmasked input sequence
            shape (batch_size, num_masked_positions)
        masked_positions
            The masked positions in the sequence
            shape (batch_size, num_masked_positions)
        masked_lm_weights
            The weight matrix containing 0 or 1 to mark the actual effect of masked positions
            shape (batch_size, num_masked_positions)
        """
        N = self._max_num_masked_position
        # Only valid tokens that are not special tokens are allowed to be masked
        valid_candidates = np.ones_like(input_ids, dtype=np.bool)
        ignore_tokens = [
            self.vocab.cls_id, self.vocab.sep_id, self.vocab.pad_id
        ]

        for ignore_token in ignore_tokens:
            # TODO(zheyuye), Update when operation += supported
            valid_candidates = valid_candidates * \
                np.not_equal(input_ids, ignore_token)
        valid_lengths = valid_lengths.astype(np.float32)
        valid_candidates = valid_candidates.astype(np.float32)
        num_masked_position = mxnp.maximum(
            1, np.minimum(N, round(valid_lengths * self._mask_prob)))

        # Get the masking probability of each position
        sample_probs = self._proposal_distribution * valid_candidates
        sample_probs /= mxnp.sum(sample_probs, axis=-1, keepdims=True)
        sample_probs = npx.stop_gradient(sample_probs)
        gumbels = mxnp.random.gumbel(np.zeros_like(sample_probs))
        # Following the instructions of the official repo, avoid duplicate positions
        # by using top-k sampling, see https://github.com/google-research/electra/issues/41
        masked_positions = npx.topk(mxnp.log(sample_probs) + gumbels,
                                    k=N,
                                    axis=-1,
                                    ret_typ='indices',
                                    dtype=np.int32)

        masked_weights = npx.sequence_mask(mxnp.ones_like(masked_positions),
                                           sequence_length=num_masked_position,
                                           use_sequence_length=True,
                                           axis=1,
                                           value=0)
        masked_positions = masked_positions * masked_weights
        length_masks = npx.sequence_mask(mxnp.ones_like(input_ids,
                                                        dtype=np.float32),
                                         sequence_length=valid_lengths,
                                         use_sequence_length=True,
                                         axis=1,
                                         value=0)
        unmasked_tokens = select_vectors_by_position(
            input_ids, masked_positions) * masked_weights
        masked_weights = masked_weights.astype(np.float32)
        replaced_positions = (mxnp.random.uniform(
            mxnp.zeros_like(masked_positions), mxnp.ones_like(
                masked_positions)) < self._replace_prob) * masked_positions
        # Deal with multiple zero values in replaced_positions, which would otherwise
        # cause the [CLS] token to be replaced
        filled = mxnp.where(replaced_positions, self.vocab.mask_id,
                            self.vocab.cls_id).astype(np.int32)
        # Masking token by replacing with [MASK]
        masked_input_ids = update_vectors_by_position(input_ids, filled,
                                                      replaced_positions)

        # Note: masked_positions is likely to contain multiple zero values if the number of masked
        # positions has not reached the maximum. However, this case rarely occurs since valid_length
        # is almost always equal to max_seq_length
        masked_input = self.MaskedInput(input_ids=masked_input_ids,
                                        masks=length_masks,
                                        unmasked_tokens=unmasked_tokens,
                                        masked_positions=masked_positions,
                                        masked_weights=masked_weights)
        return masked_input
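A small sketch of the Gumbel-top-k step used above (the toy probabilities, the uniform-based Gumbel noise, and the import `from mxnet import np, npx` are assumptions for illustration): adding Gumbel noise to the log-probabilities and taking the top-k indices samples k distinct positions without replacement.

from mxnet import np, npx

# Hypothetical sampling probabilities over 6 positions for a batch of 1.
sample_probs = np.array([[0.05, 0.30, 0.10, 0.25, 0.20, 0.10]])

# Standard Gumbel noise built from uniform samples: g = -log(-log(u)).
u = np.random.uniform(size=sample_probs.shape)
gumbels = -np.log(-np.log(u))

# The k largest perturbed log-probabilities give k distinct sampled positions.
masked_positions = npx.topk(np.log(sample_probs) + gumbels,
                            k=3, axis=-1, ret_typ='indices', dtype='int32')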
Example #6
    def forward(self, scores):
        values, indices = npx.topk(scores, axis=1, k=self.k, ret_typ='both', is_ascend=True, dtype='int32')
        # Project indices back into original shape (which is different for t==1 and t>1)
        values, indices = np.reshape(values, (-1, 1)), np.reshape(indices, (-1,))
        return indices, values
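For comparison with Example #4, a toy run of an ascending top-k that keeps both values and indices and then flattens them across rows (the scores and the import `from mxnet import np, npx` are assumptions): this mirrors the reshape above, which turns (batch*beam, k) outputs into flat candidate lists.

from mxnet import np, npx

# Hypothetical negative log-probabilities for batch*beam = 2 rows over 3 words.
scores = np.array([[2.3, 0.7, 1.9],
                   [0.4, 3.0, 1.2]])

k = 2
values, indices = npx.topk(scores, axis=1, k=k, ret_typ='both',
                           is_ascend=True, dtype='int32')
# Flatten as in the example: values -> (rows * k, 1), indices -> (rows * k,)
values, indices = np.reshape(values, (-1, 1)), np.reshape(indices, (-1,))
print(indices)  # [1 2 0 2]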