Example #1
    def forward(self, inputs, begin_state=None): # pylint: disable=arguments-differ
        """Implement forward computation.

        Parameters
        ----------
        inputs : NDArray
            The input batch, with shape `(sequence_length, batch_size)`.
        begin_state : list
            The initial hidden states.

        Returns
        -------
        out: NDArray
            The output of the model.
        out_states: list
            The list of output states of the model's encoder.
        """
        encoded = self.embedding(inputs)
        if not begin_state:
            begin_state = self.begin_state(batch_size=inputs.shape[1])
        out_states = []
        for i, (e, s) in enumerate(zip(self.encoder, begin_state)):
            encoded, state = e(encoded, s)
            out_states.append(state)
            if self._drop_h and i != len(self.encoder)-1:
                encoded = nd.Dropout(encoded, p=self._drop_h, axes=(0,))
        if self._dropout:
            encoded = nd.Dropout(encoded, p=self._dropout, axes=(0,))
        with autograd.predict_mode():
            out = self.decoder(encoded)
        return out, out_states
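
Most snippets on this page pass `axes=(0,)` to `nd.Dropout`. A minimal standalone sketch (shapes are assumptions, not from the snippet above) of what that argument does: the dropout mask has extent 1 along the listed axes, so with a TNC layout every time step shares one mask, which is the variational dropout these language models rely on.

from mxnet import nd, autograd

x = nd.ones((3, 2, 4))                  # (sequence_length, batch_size, features)
with autograd.train_mode():             # 'training'-mode Dropout is a no-op otherwise
    y = nd.Dropout(x, p=0.5, axes=(0,))
print(y)                                # all 3 time steps show the same zero pattern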
Example #2
    def forward(self, inputs, begin_state=None):  # pylint: disable=arguments-differ
        """Implement forward computation.

        Parameters
        ----------
        inputs : NDArray
            input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        begin_state : list
            list of initial recurrent state tensors whose length equals num_layers;
            each state has shape `(1, batch_size, num_hidden)`.

        Returns
        -------
        out : NDArray
            output tensor with shape `(sequence_length, batch_size, input_size)`
            when `layout` is "TNC".
        out_states : list
            list of output recurrent state tensors whose length equals num_layers;
            each state has shape `(1, batch_size, num_hidden)`.
        """
        encoded = self.embedding(inputs)
        if not begin_state:
            begin_state = self.begin_state(batch_size=inputs.shape[1])
        out_states = []
        for i, (e, s) in enumerate(zip(self.encoder, begin_state)):
            encoded, state = e(encoded, s)
            out_states.append(state)
            if self._drop_h and i != len(self.encoder) - 1:
                encoded = nd.Dropout(encoded, p=self._drop_h, axes=(0, ))
        if self._dropout:
            encoded = nd.Dropout(encoded, p=self._dropout, axes=(0, ))
        with autograd.predict_mode():
            out = self.decoder(encoded)
        return out, out_states
Example #3
    def forward(self, inputs, begin_state=None): # pylint: disable=arguments-differ
        """Implement the forward computation that the awd language model and cache model use.

        Parameters
        ----------
        inputs : NDArray
            input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        begin_state : list
            list of initial recurrent state tensors whose length equals num_layers;
            each state has shape `(1, batch_size, num_hidden)`.

        Returns
        -------
        out : NDArray
            output tensor with shape `(sequence_length, batch_size, input_size)`
            when `layout` is "TNC".
        out_states : list
            list of output recurrent state tensors whose length equals num_layers;
            each state has shape `(1, batch_size, num_hidden)`.
        encoded_raw : list
            list of the encoder layers' outputs, one per layer;
            each output has shape `(sequence_length, batch_size, num_hidden)`.
        encoded_dropped : list
            list of the encoder layers' outputs after dropout, one per layer;
            each has shape `(sequence_length, batch_size, num_hidden)`.
        """
        encoded = self.embedding(inputs)
        if not begin_state:
            begin_state = self.begin_state(batch_size=inputs.shape[1])
        out_states = []
        encoded_raw = []
        encoded_dropped = []
        for i, (e, s) in enumerate(zip(self.encoder, begin_state)):
            encoded, state = e(encoded, s)
            encoded_raw.append(encoded)
            out_states.append(state)
            if self._drop_h and i != len(self.encoder)-1:
                encoded = nd.Dropout(encoded, p=self._drop_h, axes=(0,))
                encoded_dropped.append(encoded)
        if self._dropout:
            encoded = nd.Dropout(encoded, p=self._dropout, axes=(0,))
        encoded_dropped.append(encoded)
        latent = nd.Dropout(self.latent(encoded), p=self._drop_l, axes=(0,))
        logit = self.decoder(latent.reshape(-1, self._embed_size))
        prior_logit = self.prior(encoded).reshape(-1, self._num_experts)
        prior = nd.softmax(prior_logit)
        prob = nd.softmax(logit.reshape(-1, self._vocab_size))
        prob = prob.reshape(-1, self._num_experts, self._vocab_size)
        prob = (prob * prior.expand_dims(2).broadcast_to(prob.shape)).sum(axis=1)
        out = nd.log(nd.add(prob, 1e-8)).reshape(-1, inputs.shape[1], self._vocab_size)
        return out, out_states, encoded_raw, encoded_dropped
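
The tail of this forward is a mixture of softmaxes (MoS): `num_experts` softmaxes over the vocabulary, weighted by a softmaxed prior and summed. A self-contained sketch with made-up sizes; T, N, K, V are assumptions standing in for sequence_length, batch_size, `_num_experts` and `_vocab_size`.

from mxnet import nd

T, N, K, V = 5, 2, 3, 7                                 # seq_len, batch, experts, vocab
logit = nd.random.normal(shape=(T * N * K, V))          # one softmax per expert and position
prior = nd.softmax(nd.random.normal(shape=(T * N, K)))  # mixture weights per position
prob = nd.softmax(logit).reshape(-1, K, V)
prob = (prob * prior.expand_dims(2).broadcast_to(prob.shape)).sum(axis=1)
out = nd.log(prob + 1e-8).reshape(T, N, V)              # log of the mixed distribution
print(out.shape)                                        # (5, 2, 7)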
Example #4
    def mlp(self, top_recur):
        is_train = autograd.is_training()
        if is_train:
            top_recur = nd.Dropout(data=top_recur, axes=[0], p=self.dropout_mlp)
        W_dep, b_dep = self.mlp_dep_W.data(), self.mlp_dep_b.data()
        W_head, b_head = self.mlp_head_W.data(), self.mlp_head_b.data()
        dep = leaky_relu(nd.dot(top_recur, W_dep.T) + b_dep)
        head = leaky_relu(nd.dot(top_recur, W_head.T) + b_head)
        if is_train:
            dep = nd.Dropout(data=dep, axes=[0], p=self.dropout_mlp)
            head = nd.Dropout(data=head, axes=[0], p=self.dropout_mlp)
        dep, head = nd.transpose(dep, axes=[2, 0, 1]), nd.transpose(head, axes=[2, 0, 1])
        dep_arc, dep_rel = dep[:self.mlp_arc_size], dep[self.mlp_arc_size:]
        head_arc, head_rel = head[:self.mlp_arc_size], head[self.mlp_arc_size:]
        return dep_arc, dep_rel, head_arc, head_rel
Example #5
    def forward(self, inputs, begin_state=None):
        """Implement forward computation.

        Parameters
        ----------
        inputs : NDArray
            The input batch, with shape `(sequence_length, batch_size)`.
        begin_state : list
            The initial hidden states.

        Returns
        -------
        out: NDArray
            The output of the model.
        out_states: list
            The list of output states of the model's encoder.
        encoded_raw: list
            The raw output of each encoder layer.
        encoded_dropped: list
            The dropped output of each encoder layer.
        """
        encoded = self.embedding(inputs)
        if not begin_state:
            begin_state = self.begin_state(batch_size=inputs.shape[1])
        out_states = []
        encoded_raw = []
        encoded_dropped = []
        for i, (e, s) in enumerate(zip(self.encoder, begin_state)):
            encoded, state = e(encoded, s)
            encoded_raw.append(encoded)
            out_states.append(state)
            if self._drop_h and i != len(self.encoder) - 1:
                encoded = nd.Dropout(encoded, p=self._drop_h, axes=(0, ))
                encoded_dropped.append(encoded)
        if self._dropout:
            encoded = nd.Dropout(encoded, p=self._dropout, axes=(0, ))
        encoded_dropped.append(encoded)

        latent = nd.Dropout(self.latent(encoded), p=self._drop_l, axes=(0, ))
        logit = self.decoder(latent.reshape(-1, self._embed_size))
        prior_logit = self.prior(encoded).reshape(-1, self._num_experts)
        prior = nd.softmax(prior_logit)

        prob = nd.softmax(logit.reshape(-1, self._vocab_size))
        prob = prob.reshape(-1, self._num_experts, self._vocab_size)
        prob = (prob *
                prior.expand_dims(2).broadcast_to(prob.shape)).sum(axis=1)
        out = nd.log(nd.add(prob, 1e-8)).reshape(-1, inputs.shape[1],
                                                 self._vocab_size)

        return out, out_states, encoded_raw, encoded_dropped
Example #6
    def forward(self, inputs, begin_state=None):  # pylint: disable=arguments-differ
        """Defines the forward computation. Arguments can be either
        :py:class:`NDArray` or :py:class:`Symbol`.

        Parameters
        ----------
        inputs : NDArray
            input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        begin_state : list
            list of initial recurrent state tensors whose length equals num_layers-1;
            each state has shape `(num_layers, batch_size, num_hidden)`.

        Returns
        -------
        out : NDArray
            output tensor with shape `(sequence_length, batch_size, input_size)`
            when `layout` is "TNC".
        state : list
            list of output recurrent state tensors whose length equals num_layers-1;
            each state has shape `(num_layers, batch_size, num_hidden)`.
        """
        encoded = self.embedding(inputs)
        if not begin_state:
            begin_state = self.begin_state(batch_size=inputs.shape[1])
        encoded, state = self.encoder(encoded, begin_state)
        if self._dropout:
            encoded = nd.Dropout(encoded, p=self._dropout, axes=(0, ))
        out = self.decoder(encoded)
        return out, state
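
A hedged usage sketch for a forward with this (out, state) contract: under truncated BPTT the returned state is detached and fed back as begin_state of the next batch. `gluon.rnn.LSTM` is a toy stand-in here, not the model above.

import mxnet as mx
from mxnet import nd

model = mx.gluon.rnn.LSTM(hidden_size=8)      # toy stand-in for the language model
model.initialize()
hidden = model.begin_state(batch_size=2, func=nd.zeros)
for _ in range(3):                            # pretend batches
    data = nd.random.normal(shape=(5, 2, 4))  # (seq_len, batch, features)
    output, hidden = model(data, hidden)
    hidden = [h.detach() for h in hidden]     # stop gradients between batches
print(output.shape)                           # (5, 2, 8)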
Example #7
def biLSTM(f_lstm, b_lstm, inputs, dropout_x=0.):
    """Feature extraction through BiLSTM

    Parameters
    ----------
    f_lstm : list of VariationalDropoutCell
        Forward cells, one per layer
    b_lstm : list of VariationalDropoutCell
        Backward cells, one per layer
    inputs : NDArray
        Input features of shape (seq_len, batch_size, input_dim)
    dropout_x : float
        Variational dropout rate on the inputs

    Returns
    -------
    outputs : NDArray
        Outputs of the BiLSTM layers, shape (seq_len, batch_size, 2 * hidden_dims)
    """
    for f, b in zip(f_lstm, b_lstm):
        inputs = nd.Dropout(inputs, dropout_x, axes=[0])  # important for variational dropout
        fo, _ = f.unroll(length=inputs.shape[0], inputs=inputs, layout='TNC', merge_outputs=True)
        bo, _ = b.unroll(length=inputs.shape[0], inputs=inputs.flip(axis=0), layout='TNC',
                         merge_outputs=True)
        f.reset()
        b.reset()
        inputs = nd.concat(fo, bo.flip(axis=0), dim=2)
    return inputs
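
A hedged way to drive the helper above: the original expects lists of VariationalDropoutCell, but plain gluon LSTMCells (an assumption made only to keep this self-contained) exercise the same path.

import mxnet as mx
from mxnet import nd

f_lstm = [mx.gluon.rnn.LSTMCell(8)]           # one forward layer
b_lstm = [mx.gluon.rnn.LSTMCell(8)]           # one backward layer
for cell in f_lstm + b_lstm:
    cell.initialize()

x = nd.random.normal(shape=(5, 2, 4))         # (seq_len, batch_size, input_dim)
out = biLSTM(f_lstm, b_lstm, x, dropout_x=0.0)
print(out.shape)                              # (5, 2, 16): forward + flipped backward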
Example #8
    def forward(self, inputs, begin_state=None): # pylint: disable=arguments-differ
        """Implement the forward computation that the awd language model and cache model use.

        Parameters
        ----------
        inputs : NDArray
            input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        begin_state : list
            list of initial recurrent state tensors whose length equals num_layers;
            each state has shape `(1, batch_size, num_hidden)`.

        Returns
        -------
        out : NDArray
            output tensor with shape `(sequence_length, batch_size, input_size)`
            when `layout` is "TNC".
        out_states : list
            list of output recurrent state tensors whose length equals num_layers;
            each state has shape `(1, batch_size, num_hidden)`.
        encoded_raw : list
            list of the encoder layers' outputs, one per layer;
            each output has shape `(sequence_length, batch_size, num_hidden)`.
        encoded_dropped : list
            list of the encoder layers' outputs after dropout, one per layer;
            each has shape `(sequence_length, batch_size, num_hidden)`.
        """
        encoded = self.embedding(inputs)
        if not begin_state:
            begin_state = self.begin_state(batch_size=inputs.shape[1])
        out_states = []
        encoded_raw = []
        encoded_dropped = []
        for i, (e, s) in enumerate(zip(self.encoder, begin_state)):
            encoded, state = e(encoded, s)
            encoded_raw.append(encoded)
            out_states.append(state)
            if self._drop_h and i != len(self.encoder) - 1:
                encoded = nd.Dropout(encoded, p=self._drop_h, axes=(0,))
                encoded_dropped.append(encoded)
        if self._dropout:
            encoded = nd.Dropout(encoded, p=self._dropout, axes=(0,))
        encoded_dropped.append(encoded)
        with autograd.predict_mode():
            out = self.decoder(encoded)
        return out, out_states, encoded_raw, encoded_dropped
Example #9
    def forward(self,
                inputs,
                char_inputs,
                begin_state=None,
                valid_length=None,
                masks=None):  # pylint: disable=arguments-differ
        """Implement the forward computation that the awd language model and cache model use.

        Parameters
        ----------
        inputs : NDArray
            input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        char_inputs : NDArray
            character-level inputs, used when the model is configured with a
            pretrained embedding.
        begin_state : list
            list of initial recurrent state tensors whose length equals num_layers;
            each state has shape `(1, batch_size, num_hidden)`.
        valid_length : NDArray
            optional tensor of input sequence valid lengths, shape `(batch_size,)`.
        masks : NDArray
            optional masks forwarded to the encoder.

        Returns
        -------
        out : NDArray
            output tensor with shape `(sequence_length, batch_size, input_size)`
            when `layout` is "TNC".
        encoded_raw : list
            list of the encoder layers' outputs, one per layer;
            each output has shape `(sequence_length, batch_size, num_hidden)`.
        masks : NDArray
            the masks returned by the encoder. The second and fourth elements
            of the returned tuple are unused placeholders.
        """
        if self._use_pretrained_embedding:
            encoded = self.embedding(inputs, char_inputs)
        else:
            encoded = self.embedding(inputs)
        encoded, _, encoded_raw, masks = self.encoder(
            encoded, valid_length=valid_length, masks=masks)
        if self._use_encoder_last_variational_dropout:
            encoded = nd.Dropout(encoded, p=self._dropout, axes=(1, ))
        else:
            encoded = nd.Dropout(encoded, p=self._dropout)
        encoded = encoded.swapaxes(dim1=0, dim2=1)
        out = self.decoder(encoded)
        return out, _, encoded_raw, _, masks
Example #10
    def forward(self, inputs, begin_state=None): # pylint: disable=arguments-differ
        """Defines the forward computation. Arguments can be either
        :py:class:`NDArray` or :py:class:`Symbol`."""
        encoded = self.embedding(inputs)
        if not begin_state:
            begin_state = self.begin_state(batch_size=inputs.shape[1])
        encoded, state = self.encoder(encoded, begin_state)
        if self._dropout:
            encoded = nd.Dropout(encoded, p=self._dropout, axes=(0,))
        out = self.decoder(encoded)
        return out, state
Example #11
    def data(self, ctx=None):
        """Returns a copy of this parameter on one context. Must have been
        initialized on this context before.

        Parameters
        ----------
        ctx : Context
            Desired context.

        Returns
        -------
        NDArray
            A copy of this parameter on ``ctx``, with dropout applied when a
            drop rate is set.
        """
        d = self._check_and_get(self._data, ctx)
        if self._rate:
            d = nd.Dropout(d, self._rate, self._mode, self._axes)
        return d
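
This `data` override is the hook behind weight-dropped RNNs (DropConnect): dropout hits the parameter tensor itself each time it is fetched, not the activations. A standalone illustration with assumed values:

from mxnet import nd, autograd

w = nd.ones((4, 4))                           # stands in for a recurrent weight matrix
with autograd.train_mode():
    w_dropped = nd.Dropout(w, p=0.5, mode='training')
print(w_dropped)                              # ~half the weights zeroed, rest scaled by 2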
Example #12
    def hybrid_forward(self, F, inputs, begin_state=None): # pylint: disable=arguments-differ
        """Defines the forward computation. Arguments can be either
        :py:class:`NDArray` or :py:class:`Symbol`.

        Parameters
        ----------
        inputs : NDArray
            input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        begin_state : list
            list of initial recurrent state tensors whose length equals num_layers-1;
            each state has shape `(num_layers, batch_size, num_hidden)`.

        Returns
        -------
        out : NDArray
            output tensor with shape `(sequence_length, batch_size, input_size)`
            when `layout` is "TNC".
        state : list
            list of output recurrent state tensors whose length equals num_layers-1;
            each state has shape `(num_layers, batch_size, num_hidden)`.
        """
        # XXX Temporary hack for hybridization as hybridblock does not support None inputs
        if isinstance(begin_state, list) and len(begin_state) == 0:
            begin_state = None

        encoded = self.embedding(inputs)
        if not begin_state:
            if F == nd:
                begin_state = self.begin_state(batch_size=inputs.shape[1])
            else:
                begin_state = self.begin_state(batch_size=0, func=sym.zeros)
        encoded, state = self.encoder(encoded, begin_state)
        if self._dropout:
            # NOTE: nd.Dropout ties this branch to NDArray inputs; F.Dropout
            # would be needed for this block to hybridize with Symbol inputs.
            encoded = nd.Dropout(encoded, p=self._dropout, axes=(0,))
        out = self.decoder(encoded)
        return out, state
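
Since the snippet checks `F == nd`, the dispatch it relies on is worth a look. A minimal sketch (a toy block, not the source model) of how a HybridBlock's `F` switches between `mx.nd` and `mx.sym`:

import mxnet as mx
from mxnet import nd

class DropBlock(mx.gluon.HybridBlock):
    def hybrid_forward(self, F, x):
        # F is mx.nd when run imperatively, mx.sym once hybridized
        return F.Dropout(x, p=0.5, axes=(0,))

blk = DropBlock()
blk.hybridize()                               # later calls build a cached symbolic graph
y = blk(nd.ones((3, 2, 4)))
print(y.shape)                                # (3, 2, 4)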
Example #13
def dropout(x):
    return nd.Dropout(x)
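
Two defaults matter in this one-liner: `p` is 0.5, and the default mode='training' makes the op an identity outside a training scope. A quick check (assumed usage):

from mxnet import nd, autograd

x = nd.ones((2, 3))
print(dropout(x))            # unchanged: no training scope active
with autograd.train_mode():
    print(dropout(x))        # entries are 0.0 or 2.0 (scaled by 1/(1-0.5))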
Example #14
    def forward(self,
                word_inputs,
                tag_inputs,
                arc_targets=None,
                rel_targets=None):
        """Run decoding

        Parameters
        ----------
        word_inputs : numpy.ndarray
            word indices of seq_len x batch_size
        tag_inputs : numpy.ndarray
            tag indices of seq_len x batch_size
        arc_targets : numpy.ndarray
            gold arc indices of seq_len x batch_size
        rel_targets : numpy.ndarray
            gold rel indices of seq_len x batch_size

        Returns
        -------
        tuple
            When training, (arc_accuracy, rel_accuracy, overall_accuracy, loss).
            Otherwise, (arc_accuracy, rel_accuracy, overall_accuracy, outputs)
            if gold targets are given, else just outputs, where outputs is a
            list of (arcs, rels) pairs, one per sentence.
        """
        is_train = autograd.is_training()

        def flatten_numpy(ndarray):
            """Flatten nd-array to 1-d column vector

            Parameters
            ----------
            ndarray : numpy.ndarray
                input tensor

            Returns
            -------
            numpy.ndarray
                A column vector

            """
            return np.reshape(ndarray, (-1, ), 'F')

        batch_size = word_inputs.shape[1]
        seq_len = word_inputs.shape[0]
        mask = np.greater(word_inputs, self._vocab.ROOT).astype(np.float32)
        num_tokens = int(np.sum(mask))  # non padding, non root token number

        if is_train or arc_targets is not None:
            mask_1D = flatten_numpy(mask)
            mask_1D_tensor = nd.array(mask_1D)

        unked_words = np.where(word_inputs < self._vocab.words_in_train,
                               word_inputs, self._vocab.UNK)
        word_embs = self.word_embs(nd.array(unked_words, dtype='int'))
        if self.pret_word_embs:
            word_embs = word_embs + self.pret_word_embs(nd.array(word_inputs))
        tag_embs = self.tag_embs(nd.array(tag_inputs))

        emb_inputs = nd.concat(word_embs, tag_embs,
                               dim=2)  # seq_len x batch_size x (word_dim + tag_dim)

        top_recur = biLSTM(
            self.f_lstm,
            self.b_lstm,
            emb_inputs,
            batch_size,
            dropout_x=self.dropout_lstm_input if is_train else 0)
        top_recur = nd.Dropout(data=top_recur, axes=[0], p=self.dropout_mlp)

        W_dep, b_dep = self.mlp_dep_W.data(), self.mlp_dep_b.data()
        W_head, b_head = self.mlp_head_W.data(), self.mlp_head_b.data()
        dep, head = leaky_relu(nd.dot(top_recur, W_dep.T) + b_dep), leaky_relu(
            nd.dot(top_recur, W_head.T) + b_head)
        dep, head = nd.Dropout(data=dep, axes=[0],
                               p=self.dropout_mlp), nd.Dropout(
                                   data=head, axes=[0], p=self.dropout_mlp)
        dep, head = nd.transpose(dep, axes=[2, 0,
                                            1]), nd.transpose(head,
                                                              axes=[2, 0, 1])
        dep_arc, dep_rel = dep[:self.mlp_arc_size], dep[self.mlp_arc_size:]
        head_arc, head_rel = head[:self.mlp_arc_size], head[self.mlp_arc_size:]

        W_arc = self.arc_W.data()
        arc_logits = bilinear(dep_arc,
                              W_arc,
                              head_arc,
                              self.mlp_arc_size,
                              seq_len,
                              batch_size,
                              num_outputs=1,
                              bias_x=True,
                              bias_y=False)
        # (#head x #dep) x batch_size

        flat_arc_logits = reshape_fortran(arc_logits,
                                          (seq_len, seq_len * batch_size))
        # (#head ) x (#dep x batch_size)

        arc_preds = arc_logits.argmax(0)
        # seq_len x batch_size

        if is_train or arc_targets is not None:
            correct = np.equal(arc_preds.asnumpy(), arc_targets)
            arc_correct = correct.astype(np.float32) * mask
            arc_accuracy = np.sum(arc_correct) / num_tokens
            targets_1D = flatten_numpy(arc_targets)
            losses = self.softmax_loss(flat_arc_logits, nd.array(targets_1D))
            arc_loss = nd.sum(losses * mask_1D_tensor) / num_tokens

        if not is_train:
            arc_probs = np.transpose(
                np.reshape(
                    nd.softmax(flat_arc_logits, axis=0).asnumpy(),
                    (seq_len, seq_len, batch_size), 'F'))
        # #batch_size x #dep x #head

        W_rel = self.rel_W.data()
        rel_logits = bilinear(dep_rel,
                              W_rel,
                              head_rel,
                              self.mlp_rel_size,
                              seq_len,
                              batch_size,
                              num_outputs=self._vocab.rel_size,
                              bias_x=True,
                              bias_y=True)
        # (#head x rel_size x #dep) x batch_size

        flat_rel_logits = reshape_fortran(
            rel_logits, (seq_len, self._vocab.rel_size, seq_len * batch_size))
        # (#head x rel_size) x (#dep x batch_size)

        _target_vec = nd.array(targets_1D if is_train else flatten_numpy(
            arc_preds.asnumpy())).reshape(seq_len * batch_size, 1)
        _target_mat = _target_vec * nd.ones((1, self._vocab.rel_size))

        partial_rel_logits = nd.pick(flat_rel_logits, _target_mat.T, axis=0)
        # (rel_size) x (#dep x batch_size)

        if is_train or arc_targets is not None:
            rel_preds = partial_rel_logits.argmax(0)
            targets_1D = flatten_numpy(rel_targets)
            rel_correct = np.equal(rel_preds.asnumpy(), targets_1D).astype(
                np.float32) * mask_1D
            rel_accuracy = np.sum(rel_correct) / num_tokens
            losses = self.softmax_loss(partial_rel_logits,
                                       nd.array(targets_1D))
            rel_loss = nd.sum(losses * mask_1D_tensor) / num_tokens

        if not is_train:
            rel_probs = np.transpose(
                np.reshape(
                    nd.softmax(flat_rel_logits.transpose([1, 0, 2]),
                               axis=0).asnumpy(),
                    (self._vocab.rel_size, seq_len, seq_len, batch_size), 'F'))
        # batch_size x #dep x #head x #nclasses

        if is_train or arc_targets is not None:
            loss = arc_loss + rel_loss
            correct = rel_correct * flatten_numpy(arc_correct)
            overall_accuracy = np.sum(correct) / num_tokens

        if is_train:
            return arc_accuracy, rel_accuracy, overall_accuracy, loss

        outputs = []

        for msk, arc_prob, rel_prob in zip(np.transpose(mask), arc_probs,
                                           rel_probs):
            # parse sentences one by one
            msk[0] = 1.
            sent_len = int(np.sum(msk))
            arc_pred = arc_mst(arc_prob, sent_len, msk)
            rel_prob = rel_prob[np.arange(len(arc_pred)), arc_pred]
            rel_pred = rel_argmax(rel_prob, sent_len)
            outputs.append((arc_pred[1:sent_len], rel_pred[1:sent_len]))

        if arc_targets is not None:
            return arc_accuracy, rel_accuracy, overall_accuracy, outputs
        return outputs
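
The parser flattens everything in Fortran order (`flatten_numpy`, `reshape_fortran`), which keeps each sentence contiguous when a seq_len x batch_size matrix is flattened column by column. A small standalone check with made-up values:

import numpy as np

m = np.array([[0, 3],
              [1, 4],
              [2, 5]])                        # seq_len x batch_size; column j = sentence j
print(np.reshape(m, (-1,), 'F'))              # [0 1 2 3 4 5]: sentence 0 first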
Example #15
    def forward(self,
                inputs,
                begin_state=None,
                token_types=None,
                valid_length=None,
                masked_positions=None):  # pylint: disable=arguments-differ
        """Implement the forward computation that the awd language model and cache model use.

        Parameters
        ----------
        inputs : NDArray
            input tensor with shape `(sequence_length, batch_size)`
            when `layout` is "TNC".
        begin_state : list
            list of initial recurrent state tensors whose length equals num_layers;
            each state has shape `(1, batch_size, num_hidden)`.
        token_types : NDArray
            input token type tensor, shape `(batch_size, seq_length)`.
            If the inputs contain two sequences, the token type of the first
            sequence differs from that of the second one.
        valid_length : NDArray
            optional tensor of input sequence valid lengths, shape `(batch_size,)`.
        masked_positions : NDArray
            optional tensor of positions of tokens for masked LM decoding,
            shape `(batch_size, num_masked_positions)`.

        Returns
        -------
        out : NDArray
            output tensor with shape `(sequence_length, batch_size, input_size)`
            when `layout` is "TNC".
        out_states : list
            list of output recurrent state tensors whose length equals num_layers;
            each state has shape `(1, batch_size, num_hidden)`.
        encoded_raw : list
            list of the encoder layers' outputs, one per layer;
            each output has shape `(sequence_length, batch_size, num_hidden)`.
        encoded_dropped : list
            list of the encoder layers' outputs after dropout, one per layer;
            each has shape `(sequence_length, batch_size, num_hidden)`.
        """
        batch_size = inputs.shape[1]
        inputs = nd.transpose(inputs, axes=(1, 0))
        if token_types is None:
            token_types = nd.zeros_like(inputs)
        encoded = self.embedding(inputs,
                                 token_types=token_types,
                                 valid_length=valid_length,
                                 masked_positions=masked_positions)
        encoded = nd.transpose(encoded, axes=(1, 0, 2))
        encoded = nd.Dropout(encoded, p=self._drop_i, axes=(0, ))
        if not begin_state:
            begin_state = self.begin_state(batch_size=batch_size)
        out_states = []
        encoded_raw = []
        encoded_dropped = []
        for i, (e, s) in enumerate(zip(self.encoder, begin_state)):
            encoded, state = e(encoded, s)
            encoded_raw.append(encoded)
            out_states.append(state)
            if i != len(self.encoder) - 1:
                encoded = nd.Dropout(encoded, p=self._drop_h, axes=(0, ))
                encoded_dropped.append(encoded)
        encoded = nd.Dropout(encoded, p=self._dropout, axes=(0, ))
        encoded_dropped.append(encoded)
        # mixture of softmaxes (MoS)
        latent = nd.Dropout(self.latent(encoded), p=self._drop_l, axes=(0, ))
        logit = self.decoder(latent.reshape(-1, self._embed_size))

        prior_logit = self.prior(encoded).reshape(-1, self._num_experts)
        prior = nd.softmax(prior_logit, axis=-1)

        prob = nd.softmax(logit.reshape(-1, self._vocab_size), axis=-1)
        prob = prob.reshape(-1, self._num_experts, self._vocab_size)
        prob = (prob *
                prior.expand_dims(2).broadcast_to(prob.shape)).sum(axis=1)

        out = nd.log(nd.add(prob, 1e-8)).reshape(-1, batch_size,
                                                 self._vocab_size)

        return out, out_states, encoded_raw, encoded_dropped
Example #16
def dropout2(X, drop_rate):
    # Force training mode globally so the (default) 'training'-mode Dropout
    # actually fires; note this flag is never reset here.
    autograd.set_training(True)
    Z = nd.zeros_like(X)
    # Write the result into the preallocated Z via the out= keyword.
    nd.Dropout(X, p=drop_rate, out=Z)
    return Z
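
A quick sanity check of `dropout2` (assumed usage): survivors are scaled by 1/(1-p), so the expected value of each entry is preserved.

from mxnet import nd

X = nd.ones((2, 4))
Z = dropout2(X, 0.4)
print(Z)                                      # entries are 0.0 or 1/(1-0.4) = ~1.667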