Example #1
    def forward(self, data, steps):  # pylint: disable=arguments-differ
        """
        Applies positional embeddings to input data.

        :param data: Input data. Shape: (batch, length or 1, num_embed)
        :param steps: Optional steps input. If given, shape is (batch_size or 1, seq_len)

        :return: Data with positional embeddings added
        """
        # self.weight.data() has shape (max_seq_len, num_embed)
        if steps is None:
            # (batch, length, num_embed)
            pos_embedding = npx.slice_like(np.expand_dims(self.weight.data(),
                                                          axis=0),
                                           data,
                                           axes=(1, ))
        else:
            # (batch_size or 1, seq_len, num_embed)
            pos_embedding = npx.embedding(steps, self.weight.data(),
                                          self.max_seq_len, self.num_embed)

        if self.weight_type == 'fixed':
            pos_embedding = npx.stop_gradient(pos_embedding)

        if self.scale_up_input:
            data = data * (self.num_embed**0.5)

        return data + pos_embedding
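
A minimal, self-contained sketch of the same pattern (the table `weight` and all shapes below are made-up stand-ins, not taken from the class above): a fixed positional table is sliced to the input length and its gradient is blocked with npx.stop_gradient.

from mxnet import autograd, np, npx
npx.set_np()

max_seq_len, num_embed = 8, 4
weight = np.random.uniform(size=(max_seq_len, num_embed))  # stand-in embedding table

data = np.ones((2, 5, num_embed))  # (batch, length, num_embed)
data.attach_grad()
with autograd.record():
    # (1, max_seq_len, num_embed) sliced to (1, length, num_embed) along axis 1
    pos = npx.slice_like(np.expand_dims(weight, axis=0), data, axes=(1,))
    loss = (data + npx.stop_gradient(pos)).sum()
loss.backward()
print(data.grad.shape)  # (2, 5, 4): gradients reach data, not the fixed table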
Example #2
def gumbel_softmax(logits,
                   temperature: float = 1.0,
                   eps: float = 1E-10,
                   hard: bool = True,
                   use_np_gumbel: bool = True):
    r"""Perform the gumbel-softmax trick to generate differentiable one-hot vectors from the input
    logits.

    Here, the gumbel distribution is

    Gumbel(\alpha) = -\log(-\log U) + \log \alpha, in which U \sim Uniform(0, 1).

    A nice property of Gumbel is:

    \argmax_i \{Gumbel(\alpha_i)\} \sim Multinomial(\alpha)

    The Gumbel-Softmax trick is to use the softmax + straight-through estimator to produce
    one-hot vectors that represent the sampling result.

    References:

        1. https://en.wikipedia.org/wiki/Gumbel_distribution
        2. [ICLR2017] Categorical Reparameterization with Gumbel-Softmax

    Parameters
    ----------
    logits
        Logits. Shape (..., V)
    temperature
        The temperature that controls the smoothness of the softmax output.
        Lower temperatures push the output closer to one-hot.
    eps
        The epsilon added for numerical stability of the gradient
    hard
        Whether to use the straight-through estimator to produce one-hot vectors.
    use_np_gumbel
        Whether to use the np.random.gumbel operator

    Returns
    -------
    ret
        The returned output. Shape (..., V)
    """
    # TODO(sxjscience) Investigate the impact of random.gumbel:
    #  Actually, random.gumbel has no eps and may have problems in calculating the gradient.
    if use_np_gumbel:
        gumbels = np.random.gumbel(np.zeros_like(logits))
    else:
        u = np.random.uniform(np.zeros_like(logits), 1)
        gumbels = -np.log(-np.log(u + eps) + eps)
    y = npx.softmax((gumbels + logits) / temperature, axis=-1)
    if hard:
        # Straight-through estimator: the forward pass emits the one-hot argmax,
        # while the backward pass uses the gradient of the soft sample y.
        y_hard = np.max(y, axis=-1, keepdims=True) == y
        y_hard = npx.stop_gradient(y_hard - y) + y
        return y_hard
    else:
        return y
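
A quick usage sketch (assuming the gumbel_softmax definition above is in scope; the logits and loss weights are made-up values): with hard=True the forward output is one-hot, while the straight-through estimator still lets gradients flow back to the logits.

from mxnet import autograd, np, npx
npx.set_np()

logits = np.array([[1.0, 2.0, 0.5]])
logits.attach_grad()
with autograd.record():
    y = gumbel_softmax(logits, temperature=0.5, hard=True)
    loss = (y * np.array([[0.0, 1.0, 2.0]])).sum()
loss.backward()
print(y)            # a one-hot row such as [[0. 1. 0.]]
print(logits.grad)  # non-zero: gradients pass through the soft sample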
Example #3
    def forward(self, x: np.ndarray) -> np.ndarray:
        # Shape: (length,)
        length_array = npx.arange_like(x, axis=1)
        # Matrix with lower triangle and main diagonal set to 0, upper triangle set to 1
        # Shape: (length, length)
        bias = npx.broadcast_greater(np.expand_dims(length_array, axis=0),
                                     np.expand_dims(length_array, axis=1))
        bias = bias * -C.LARGE_VALUES[self._dtype]
        bias = np.expand_dims(bias, axis=0)
        return npx.stop_gradient(bias)
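
The same upper-triangular bias can be built standalone; a small sketch assuming float32 and substituting -1e9 for C.LARGE_VALUES:

from mxnet import np, npx
npx.set_np()

length = 4
idx = np.arange(length)
# bias[i, j] = -1e9 where j > i (future positions), 0 elsewhere
bias = (np.expand_dims(idx, axis=0) > np.expand_dims(idx, axis=1)).astype('float32') * -1e9
bias = npx.stop_gradient(np.expand_dims(bias, axis=0))  # (1, length, length)
print(bias[0])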
Example #4
def test_stop_gradient():
    A = np.ones((INT_OVERFLOW, 2))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.stop_gradient(A * 3)
    assert B.shape == (INT_OVERFLOW, 2)
    assert B[0][0] == 3
    B.backward()
    # should be 3 if not for stop_gradient()
    assert A.grad[0][0] == 0
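
For contrast, the same computation without npx.stop_gradient propagates the constant factor into the gradient. A small-shape sketch (INT_OVERFLOW above is only needed for the large-tensor test):

import mxnet as mx
from mxnet import np, npx
npx.set_np()

A = np.ones((2, 2))
A.attach_grad()
with mx.autograd.record():
    B = A * 3
B.backward()
print(A.grad[0][0])  # 3.0, versus 0.0 when wrapped in npx.stop_gradient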
Example #5
    def dynamic_masking(self, input_ids, valid_lengths):
        # TODO(zheyuye): add two flags, `disallow_from_mask` and `already_masked`,
        # to control the masking status of each position in the sequence.
        """
        Generate masking positions on-the-fly instead of during preprocessing
        Parameters
        ----------
        input_ids
            The batchified input_ids with shape (batch_size, max_seq_length)
        valid_lengths
            The batchified valid_lengths with shape (batch_size, )
        Returns
        ------
        masked_input_ids
            The masked input sequence with 15% tokens are masked with [MASK]
            shape (batch_size, max_seq_length)
        length_masks
            The masking matrix for the whole sequence that indicates the positions
            are greater than valid_length.

            shape (batch_size, max_seq_length)
        unmasked_tokens
            The original tokens that appear in the unmasked input sequence
            shape (batch_size, num_masked_positions)
        masked_positions
            The masking positions in mx.np.ndarray with shape (batch_size, num_masked_positions)
            shape (batch_size, num_masked_positions)
        masked_lm_weights
            The weight matrix containing 0 or 1 to mark the actual effect of masked positions
            shape (batch_size, num_masked_positions)
        """
        N = self._max_num_masked_position
        # Only valid tokens that are not special tokens are allowed to be masked
        valid_candidates = np.ones_like(input_ids, dtype=np.bool)
        ignore_tokens = [
            self.vocab.cls_id, self.vocab.sep_id, self.vocab.pad_id
        ]

        for ignore_token in ignore_tokens:
            # TODO(zheyuye): update this once the in-place += operator is supported
            valid_candidates = valid_candidates * \
                np.not_equal(input_ids, ignore_token)
        valid_lengths = valid_lengths.astype(np.float32)
        valid_candidates = valid_candidates.astype(np.float32)
        num_masked_position = np.maximum(
            1, np.minimum(N, np.round(valid_lengths * self._mask_prob)))

        # Get the masking probability of each position
        sample_probs = self._proposal_distribution * valid_candidates
        sample_probs /= np.sum(sample_probs, axis=-1, keepdims=True)
        sample_probs = npx.stop_gradient(sample_probs)
        gumbels = np.random.gumbel(np.zeros_like(sample_probs))
        # Follow the official repo and sample positions with Gumbel top-k to avoid
        # duplicate positions, see https://github.com/google-research/electra/issues/41
        masked_positions = npx.topk(np.log(sample_probs) + gumbels,
                                    k=N,
                                    axis=-1,
                                    ret_typ='indices',
                                    dtype=np.int32)

        masked_weights = npx.sequence_mask(np.ones_like(masked_positions),
                                           sequence_length=num_masked_position,
                                           use_sequence_length=True,
                                           axis=1,
                                           value=0)
        masked_positions = masked_positions * masked_weights
        length_masks = npx.sequence_mask(np.ones_like(input_ids,
                                                      dtype=np.float32),
                                         sequence_length=valid_lengths,
                                         use_sequence_length=True,
                                         axis=1,
                                         value=0)
        unmasked_tokens = select_vectors_by_position(
            input_ids, masked_positions) * masked_weights
        masked_weights = masked_weights.astype(np.float32)
        replaced_positions = (np.random.uniform(
            np.zeros_like(masked_positions), np.ones_like(
                masked_positions)) < self._replace_prob) * masked_positions
        # Handle multiple zero values in replaced_positions, which would
        # otherwise cause position 0 (the [CLS] token) to be replaced
        filled = np.where(replaced_positions, self.vocab.mask_id,
                          self.vocab.cls_id).astype(np.int32)
        # Masking token by replacing with [MASK]
        masked_input_ids = update_vectors_by_position(input_ids, filled,
                                                      replaced_positions)

        # Note: masked_positions is likely to contain multiple zero values when the
        # number of masked positions has not reached the maximum. However, this case
        # rarely occurs since valid_length is almost always equal to max_seq_length
        masked_input = self.MaskedInput(input_ids=masked_input_ids,
                                        masks=length_masks,
                                        unmasked_tokens=unmasked_tokens,
                                        masked_positions=masked_positions,
                                        masked_weights=masked_weights)
        return masked_input
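
The core sampling step above can be exercised in isolation. A sketch with made-up probabilities (batch_size 1, seq_len 4, k=2) showing how Gumbel noise plus topk draws distinct positions in proportion to sample_probs:

from mxnet import np, npx
npx.set_np()

sample_probs = np.array([[0.1, 0.4, 0.2, 0.3]])  # (batch_size, seq_len)
gumbels = np.random.gumbel(np.zeros_like(sample_probs))
masked_positions = npx.topk(np.log(sample_probs) + gumbels,
                            k=2, axis=-1, ret_typ='indices', dtype=np.int32)
print(masked_positions)  # e.g. [[1 3]]: two distinct sampled positions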