Exemplo n.º 1
0
def _get_rnn_cell(mode, num_layers, input_size, hidden_size, dropout,
                  weight_dropout, var_drop_in, var_drop_state, var_drop_out):
    """create rnn cell given specs"""
    rnn_cell = rnn.SequentialRNNCell()
    with rnn_cell.name_scope():
        for i in range(num_layers):
            if mode == 'rnn_relu':
                cell = rnn.RNNCell(hidden_size, 'relu', input_size=input_size)
            elif mode == 'rnn_tanh':
                cell = rnn.RNNCell(hidden_size, 'tanh', input_size=input_size)
            elif mode == 'lstm':
                cell = rnn.LSTMCell(hidden_size, input_size=input_size)
            elif mode == 'gru':
                cell = rnn.GRUCell(hidden_size, input_size=input_size)
            if var_drop_in + var_drop_state + var_drop_out != 0:
                cell = contrib.rnn.VariationalDropoutCell(
                    cell, var_drop_in, var_drop_state, var_drop_out)

            rnn_cell.add(cell)
            if i != num_layers - 1 and dropout != 0:
                rnn_cell.add(rnn.DropoutCell(dropout))

            if weight_dropout:
                apply_weight_drop(rnn_cell, 'h2h_weight', rate=weight_dropout)

    return rnn_cell
Exemplo n.º 2
0
 def __init__(self, voc, **kwargs):
     """Create the model's child layers.

     Builds two strided 5x5 convolutions, two 2x2 max-pool layers, two
     batch-norm layers, two dense layers (the second with ``voc`` output
     units), and single-layer LSTM encoder and decoder stacks.

     Parameters
     ----------
     voc : int
         Number of output units of ``dense2``. Presumably the vocabulary
         size; confirm against the caller.
     **kwargs
         Forwarded unchanged to the parent block's constructor.
     """
     # NOTE(review): n_nodes_cnn, n_nodes_hidden and n_nodes_rnn are not
     # defined in this method; they are presumably module-level
     # hyperparameters. Confirm where they are set.
     super(Model, self).__init__(**kwargs)
     # Children are created inside name_scope(). In Gluon 1.x this affects
     # auto-generated parameter names, so the creation order below likely
     # matters (e.g. for checkpoint loading). Confirm before reordering.
     with self.name_scope():
         self.conv1 = nn.Conv2D(channels=n_nodes_cnn,
                                kernel_size=5,
                                strides=2)
         self.conv2 = nn.Conv2D(channels=n_nodes_cnn,
                                kernel_size=5,
                                strides=2)
         self.pool1 = nn.MaxPool2D(pool_size=2)
         self.pool2 = nn.MaxPool2D(pool_size=2)
         self.norm1 = nn.BatchNorm()
         self.norm2 = nn.BatchNorm()
         self.dense1 = nn.Dense(n_nodes_hidden)
         # Output projection: one unit per entry of `voc`.
         self.dense2 = nn.Dense(voc)
         # Encoder: a sequential stack holding a single LSTM cell.
         self.encoder = rnn.SequentialRNNCell()
         with self.encoder.name_scope():
             self.encoder.add(rnn.LSTMCell(n_nodes_rnn))
         # Decoder: a separate single-LSTM stack; it shares no parameters
         # with the encoder.
         self.decoder = rnn.SequentialRNNCell()
         with self.decoder.name_scope():
             self.decoder.add(rnn.LSTMCell(n_nodes_rnn))