Example #1
0
    def __init__(self, rnn_type, bidirectional, num_layers, hidden_size,
                 dropout, embeddings, z_size):
        """Build the VAE encoder: embeddings -> recurrent layer -> h2z.

        Args:
            rnn_type: ``"SRU"`` to use ``onmt.SRU``; otherwise the name of a
                ``torch.nn`` recurrent class (looked up with ``getattr``).
            bidirectional: whether the recurrent layer is bidirectional.
            num_layers: number of stacked recurrent layers.
            hidden_size: total hidden size summed over directions; it must
                be divisible by the number of directions (2 if
                ``bidirectional`` else 1), and each direction gets
                ``hidden_size // num_directions`` units.
            dropout: dropout passed to the recurrent layer.
            embeddings: embedding module; must expose ``embedding_size``,
                which is used as the recurrent input size.
            z_size: latent size; ``h2z`` outputs ``2 * z_size`` features
                (presumably mean and log-variance of the latent -- confirm
                against ``forward``).
        """
        super(VaeEncoder, self).__init__()

        directions = 2 if bidirectional else 1
        assert hidden_size % directions == 0
        # Each direction gets an equal share of the requested hidden size.
        per_direction = hidden_size // directions

        # Registration order (embeddings, rnn, h2z) is kept as-is because
        # it determines the order of parameters() / state_dict() entries.
        self.embeddings = embeddings

        use_sru = rnn_type == "SRU"
        # SRU doesn't support PackedSequence, so packing is disabled for it.
        self.no_pack_padded_seq = use_sru
        # NOTE(review): presumably a KL-annealing coefficient and its
        # increment; their use is not visible here -- confirm in caller.
        self.varcoeff = 0.0
        self.varstep = 0.1

        rnn_kwargs = dict(input_size=embeddings.embedding_size,
                          hidden_size=per_direction,
                          num_layers=num_layers,
                          dropout=dropout,
                          bidirectional=bidirectional)
        rnn_cls = onmt.SRU if use_sru else getattr(nn, rnn_type)
        self.rnn = rnn_cls(**rnn_kwargs)

        # NOTE(review): input width is the per-direction hidden size; if
        # forward concatenates both directions before h2z, this would need
        # to be the total hidden size instead -- verify.
        self.h2z = nn.Linear(per_direction, z_size * 2)
Example #2
0
    def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers,
                   dropout):
        """Private helper that constructs the decoder's recurrent layer.

        Args:
            rnn_type: ``"SRU"`` selects ``onmt.SRU``; any other value is
                treated as the name of a ``torch.nn`` recurrent class and
                resolved with ``getattr`` (so an unknown name raises
                ``AttributeError``).
            input_size: size of each input step, passed positionally.
            hidden_size: hidden state size, passed positionally.
            num_layers: number of stacked layers.
            dropout: dropout passed to the constructor.

        Returns:
            The newly constructed recurrent module. ``bidirectional`` is not
            passed, so the chosen constructor's default applies.
        """
        if rnn_type == "SRU":
            rnn_cls = onmt.SRU
        else:
            rnn_cls = getattr(nn, rnn_type)
        return rnn_cls(input_size, hidden_size,
                       num_layers=num_layers, dropout=dropout)