def __call__(self, X, state=None):
    """Run the model's forward function on a batch of token indices.

    ``X`` is transposed and one-hot encoded with depth ``self.vocab_size``,
    then passed to ``self.forward_fn`` together with the hidden state and
    ``self.params``. When ``state`` is omitted, a fresh state is built with
    ``self.init_state_h0``, sized from axis 1 of the encoded input.
    """
    encoded = npx.one_hot(X.T, self.vocab_size)
    # Per the original inline note the encoded input is
    # (num_steps, batch_size, vocab_size), so axis 1 is the batch axis.
    if state is not None:
        return self.forward_fn(encoded, state, self.params)
    fresh_state = self.init_state_h0(encoded.shape[1])
    return self.forward_fn(encoded, fresh_state, self.params)
Beispiel #2
0
 def forward(self, inputs, state):
     """One-hot encode ``inputs``, run ``self.rnn`` and project with ``self.dense``.

     Returns the dense-layer output together with the state produced by
     ``self.rnn``.
     """
     one_hot_steps = npx.one_hot(inputs.T, self.vocab_size)
     rnn_out, new_state = self.rnn(one_hot_steps, state)
     # Collapse every axis but the last so the dense layer receives
     # (`num_steps` * `batch_size`, `num_hiddens`); its output is then
     # (`num_steps` * `batch_size`, `vocab_size`).
     hidden_size = rnn_out.shape[-1]
     flat = rnn_out.reshape(-1, hidden_size)
     return self.dense(flat), new_state
Beispiel #3
0
 def forward(self, inputs, state):
     """Apply the recurrent layer to one-hot tokens, then the dense head.

     ``inputs`` is transposed and one-hot encoded with depth
     ``self.vocab_size`` before ``self.rnn``; the flattened RNN output is fed
     to ``self.dense``. Returns ``(dense_output, state)``.
     """
     Y, state = self.rnn(npx.one_hot(inputs.T, self.vocab_size), state)
     # Flatten Y to (`num_steps` * `batch_size`, `num_hiddens`) so that the
     # dense layer yields (`num_steps` * `batch_size`, `vocab_size`).
     return self.dense(Y.reshape(-1, Y.shape[-1])), state
def test_one_hot():
    """Check ``npx.one_hot`` forward and backward on INT_OVERFLOW elements.

    The input is all zeros, so every output row must be ``[1, 0]``. The
    gradient flowing back to the index input is expected to be zero.
    """
    A = np.zeros((INT_OVERFLOW,))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.one_hot(A, 2)
    assert B.shape == (INT_OVERFLOW, 2)
    # Check both entries of a row so an encoder that fills the whole row
    # cannot pass.
    assert B[0][0] == 1
    assert B[0][1] == 0
    # Also check the far end of the tensor. INT_OVERFLOW is presumably
    # beyond the int32 range (defined outside this view), which is what a
    # large-tensor test needs to reach.
    assert B[-1][0] == 1
    assert B[-1][1] == 0
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, )
    assert A.grad[0] == 0
    assert A.grad[-1] == 0
Beispiel #5
0
    def _process_input(self, x, state, ctx):
        """Prepare a batch of token indices for the recurrent forward pass.

        Parameters
        ----------
        x
            2-D index array whose first axis is taken as the batch axis.
        state
            Hidden state, or None to create one with ``self.begin_state``.
        ctx
            Device context on which the returned input is placed.

        Returns
        -------
        The transposed ``x`` one-hot encoded with depth ``self.num_inputs``
        (batch now on axis 1) and located on ``ctx``, plus the state.
        """
        # The message reads: "x must be 2-D, but the actual x.shape=...".
        assert len(x.shape) == 2, f"x输入的要求是2D,但实际的x.shape={x.shape}"
        batch_size = x.shape[0]
        # Move the compact index tensor to the target device *before* one-hot
        # encoding. The encoded tensor has ``self.num_inputs`` times as many
        # elements, so copying it afterwards moves far more data for the same
        # result.
        x = x.T.as_in_ctx(ctx)
        x = npx.one_hot(x, self.num_inputs)

        if state is None:
            state = self.begin_state(batch_size, ctx)
        return x, state
Beispiel #6
0
 def __call__(self, X, state):
     """One-hot encode the transposed ``X`` and delegate to ``self.forward_fn``.

     ``state`` is passed through unchanged together with ``self.params``.
     """
     inputs_onehot = npx.one_hot(X.T, self.vocab_size)
     return self.forward_fn(inputs_onehot, state, self.params)
Beispiel #7
0
    def get_corrupted_tokens(self, inputs, original_tokens, masked_positions, logits):
        """
        Sample from the generator to create corrupted input.

        Parameters
        ----------
        inputs
            The masked input
            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)
        original_tokens
            The original tokens that appear in the unmasked input sequence
            Shape (batch_size, num_masked_positions).
        masked_positions
            The masked position of the sequence
            Shape (batch_size, num_masked_positions).
        logits
            The logits of each tokens
            Shape (batch_size, num_masked_positions, vocab_size)

        Returns
        -------
        corrupted_tokens
            The token sampled for every masked position (argmax over the
            vocabulary axis of the Gumbel-softmax output)
            Shape (batch_size, num_masked_positions)
        fake_data
            `inputs` with the sampled tokens written at the masked positions
            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)
        labels
            1 at masked positions whose sampled token differs from the
            original token, 0 elsewhere
            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)
        """

        if self._disallow_correct:
            # Push the logit of the original (correct) token far down so it is
            # effectively never sampled. The one-hot must be built from the
            # vocabulary ids in `original_tokens`; `masked_positions` holds
            # sequence indices, which would penalise unrelated tokens.
            disallow = npx.one_hot(original_tokens, depth=self.vocab_size, dtype=self._dtype)
            logits = logits - 1000.0 * disallow
        # gumbel_softmax() samples from the logits with a noise of Gumbel distribution.
        # NOTE(review): the position/sampling helpers are called without the
        # legacy `F` argument (undefined in this method), matching the
        # add_vectors_by_position call below. Confirm the helpers' signatures.
        prob = gumbel_softmax(
            logits,
            temperature=self._temperature,
            eps=self._gumbel_eps,
            use_np_gumbel=False)
        # Argmax over the vocabulary axis. Assuming gumbel_softmax keeps the
        # shape of `logits`, this is (batch_size, num_masked_positions).
        corrupted_tokens = np.argmax(prob, axis=-1).astype(np.int32)

        # masked_positions is batch-major, so bring TN inputs to NT first.
        if self.disc_backbone.layout == 'TN':
            inputs = inputs.T
        original_data = update_vectors_by_position(
            inputs, original_tokens, masked_positions)
        fake_data = update_vectors_by_position(
            inputs, corrupted_tokens, masked_positions)
        updates_mask = add_vectors_by_position(np.zeros_like(inputs),
                np.ones_like(masked_positions), masked_positions)
        # Dealing with multiple zeros in masked_positions which
        # results in a non-zero value in the first index [CLS]
        updates_mask = np.minimum(updates_mask, 1)
        labels = updates_mask * np.not_equal(fake_data, original_data)
        if self.disc_backbone.layout == 'TN':
            return corrupted_tokens, fake_data.T, labels.T
        else:
            return corrupted_tokens, fake_data, labels