Example #1
    def get_end_logits(self, contextual_embedding, start_positions, p_mask):
        """

        Parameters
        ----------
        contextual_embedding
            Shape (batch_size, sequence_length, C)
        start_positions
            Shape (batch_size, N)
            We process multiple candidates simultaneously
        p_mask
            Shape (batch_size, sequence_length)

        Returns
        -------
        end_logits
            Shape (batch_size, N, sequence_length)
        """
        # Select the features at the start_positions
        # start_features will have shape (batch_size, N, C)
        start_features = select_vectors_by_position(contextual_embedding, start_positions)
        # Concatenate the start_feature and the contextual_embedding
        contextual_embedding = np.expand_dims(contextual_embedding, axis=1)  # (B, 1, T, C)
        start_features = np.expand_dims(start_features, axis=2)  # (B, N, 1, C)
        concat_features = np.concatenate(
            [npx.broadcast_like(start_features, contextual_embedding, 2, 2),
             npx.broadcast_like(contextual_embedding, start_features, 1, 1)],
            axis=-1)  # (B, N, T, 2C)
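        # Compute a scalar score for every (start candidate, end position) pair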
        end_scores = self.end_scores(concat_features)
        end_scores = np.squeeze(end_scores, -1)
        end_logits = masked_logsoftmax(end_scores, mask=np.expand_dims(p_mask, axis=1),
                                       axis=-1)
        return end_logits
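The broadcasting in this snippet is easier to follow with concrete shapes. The following is a minimal sketch in plain NumPy (not the mx.np/npx API used above); np.take_along_axis and np.broadcast_to stand in for select_vectors_by_position and npx.broadcast_like, and all sizes are made up for illustration.

import numpy as np

B, T, C, N = 2, 5, 4, 3  # illustrative batch size, sequence length, channels, start candidates
contextual_embedding = np.random.randn(B, T, C).astype(np.float32)
start_positions = np.random.randint(0, T, size=(B, N))

# Stand-in for select_vectors_by_position: gather the feature vector at each start index
start_features = np.take_along_axis(contextual_embedding,
                                    start_positions[..., None], axis=1)  # (B, N, C)

# Broadcast both tensors to (B, N, T, C) and concatenate along the channel axis,
# mirroring the two npx.broadcast_like calls above
ctx = np.broadcast_to(contextual_embedding[:, None, :, :], (B, N, T, C))
start = np.broadcast_to(start_features[:, :, None, :], (B, N, T, C))
concat_features = np.concatenate([start, ctx], axis=-1)  # (B, N, T, 2C)
print(concat_features.shape)  # (2, 3, 5, 8)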
Example #2
    def dynamic_masking(self, input_ids, valid_lengths):
        # TODO(zheyuye), add two additional flags `disallow_from_mask` and `already_masked`
        # that control the masking status for each position in the sequence.
        """
        Generate masking positions on-the-fly instead of during preprocessing.

        Parameters
        ----------
        input_ids
            The batchified input_ids with shape (batch_size, max_seq_length)
        valid_lengths
            The batchified valid_lengths with shape (batch_size, )
        Returns
        -------
        masked_input_ids
            The masked input sequence, in which about 15% of the tokens are replaced with [MASK]
            shape (batch_size, max_seq_length)
        length_masks
            The mask over the whole sequence indicating whether each position lies within
            valid_length (1) or beyond it (0)
            shape (batch_size, max_seq_length)
        unmasked_tokens
            The original tokens at the masked positions of the input sequence
            shape (batch_size, num_masked_positions)
        masked_positions
            The masked positions as an mx.np.ndarray
            shape (batch_size, num_masked_positions)
        masked_weights
            The weight matrix of 0s and 1s marking which masked positions are actually in effect
            shape (batch_size, num_masked_positions)
        """
        N = self._max_num_masked_position
        # Only valid tokens, excluding special tokens, are allowed to be masked
        valid_candidates = F.np.ones_like(input_ids, dtype=np.bool_)
        ignore_tokens = [
            self.vocab.cls_id, self.vocab.sep_id, self.vocab.pad_id
        ]

        for ignore_token in ignore_tokens:
            # TODO(zheyuye), Update when operation += supported
            valid_candidates = valid_candidates * \
                F.np.not_equal(input_ids, ignore_token)
        valid_lengths = valid_lengths.astype(np.float32)
        valid_candidates = valid_candidates.astype(np.float32)
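        # Number of tokens to mask per sequence: mask_prob of valid_length, clipped to [1, N]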
        num_masked_position = F.np.maximum(
            1, F.np.minimum(N, round(valid_lengths * self._mask_prob)))

        # Get the masking probability of each position
        sample_probs = self._proposal_distribution * valid_candidates
        sample_probs /= F.np.sum(sample_probs, axis=-1, keepdims=True)
        sample_probs = F.npx.stop_gradient(sample_probs)
        gumbels = F.np.random.gumbel(F.np.zeros_like(sample_probs))
        # Following the instructions of the official repo, avoid duplicate positions
        # by using top-k Gumbel sampling, see https://github.com/google-research/electra/issues/41
        masked_positions = F.npx.topk(F.np.log(sample_probs) + gumbels,
                                      k=N,
                                      axis=-1,
                                      ret_typ='indices',
                                      dtype=np.int32)

        masked_weights = F.npx.sequence_mask(
            F.np.ones_like(masked_positions),
            sequence_length=num_masked_position,
            use_sequence_length=True,
            axis=1,
            value=0)
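        # masked_weights is 1 for the first num_masked_position slots and 0 for the padded slots;
        # zero out the position indices in those padded slots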
        masked_positions = masked_positions * masked_weights
        length_masks = F.npx.sequence_mask(F.np.ones_like(input_ids,
                                                          dtype=np.float32),
                                           sequence_length=valid_lengths,
                                           use_sequence_length=True,
                                           axis=1,
                                           value=0)
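        # Gather the original token ids at the masked positions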
        unmasked_tokens = select_vectors_by_position(
            input_ids, masked_positions) * masked_weights
        masked_weights = masked_weights.astype(np.float32)
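        # For each masked position, decide with probability _replace_prob whether the token
        # is actually replaced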
        replaced_positions = (F.np.random.uniform(
            F.np.zeros_like(masked_positions), F.np.ones_like(
                masked_positions)) < self._replace_prob) * masked_positions
        # Deal with the multiple zero values in replaced_positions, which would otherwise cause
        # the [CLS] token to be replaced
        filled = F.np.where(replaced_positions, self.vocab.mask_id,
                            self.vocab.cls_id).astype(np.int32)
        # Mask tokens by replacing them with [MASK]
        masked_input_ids = update_vectors_by_position(input_ids, filled,
                                                      replaced_positions)

        # Note: masked_positions is likely to contain multiple zero values if the number of
        # masked positions has not reached the maximum. However, this case hardly occurs since
        # valid_length is almost always equal to max_seq_length
        masked_input = self.MaskedInput(input_ids=masked_input_ids,
                                        masks=length_masks,
                                        unmasked_tokens=unmasked_tokens,
                                        masked_positions=masked_positions,
                                        masked_weights=masked_weights)
        return masked_input
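The comment about https://github.com/google-research/electra/issues/41 refers to the Gumbel-top-k trick used above to pick the masked positions: adding independent Gumbel noise to the log-probabilities and taking the top-k indices samples k distinct positions, so duplicates cannot occur. The following is a minimal sketch of that trick in plain NumPy (the sequence length, mask probability, seed and the 1e-12 stabilizer are illustrative, and a uniform distribution stands in for self._proposal_distribution).

import numpy as np

rng = np.random.default_rng(0)
seq_len, mask_prob, max_masked = 12, 0.15, 3  # illustrative sizes
valid_candidates = np.ones(seq_len, dtype=np.float32)
valid_candidates[[0, seq_len - 1]] = 0.0  # e.g. [CLS] and [SEP] must never be masked

# Uniform proposal distribution over the valid positions
sample_probs = valid_candidates / valid_candidates.sum()
num_masked = max(1, min(max_masked, round(seq_len * mask_prob)))

# Gumbel noise: -log(-log(U)) with U ~ Uniform(0, 1)
gumbels = -np.log(-np.log(rng.uniform(size=seq_len)))
scores = np.log(sample_probs + 1e-12) + gumbels  # zero-probability positions stay far below the rest
masked_positions = np.argsort(scores)[::-1][:num_masked]  # top-k indices, all distinct
print(sorted(masked_positions.tolist()))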