コード例 #1
0
ファイル: gnmt.py プロジェクト: tayyabjsr13/gluon-nlp
 def __init__(self,
              cell_type='lstm',
              attention_cell='scaled_luong',
              num_layers=2,
              hidden_size=128,
              dropout=0.0,
              use_residual=True,
              output_attention=False,
              i2h_weight_initializer=None,
              h2h_weight_initializer=None,
              i2h_bias_initializer='zeros',
              h2h_bias_initializer='zeros',
              prefix=None,
              params=None):
     """Create the decoder's attention cell, dropout layer and recurrent stack.

     Parameters
     ----------
     cell_type : str or type
         Resolved through ``_get_cell_type``; the result is called once per
         layer to build ``self.rnn_cells`` (presumably a name such as
         ``'lstm'`` maps to a cell class -- confirm in ``_get_cell_type``).
     attention_cell : str
         Resolved through ``_get_attention_cell`` with ``units=hidden_size``
         and stored as ``self.attention_cell``.
     num_layers : int
         Number of recurrent cells added to ``self.rnn_cells``.
     hidden_size : int
         Hidden size of every recurrent cell and ``units`` of the attention cell.
     dropout : float
         Passed to ``nn.Dropout`` to build ``self.dropout_layer``.
     use_residual : bool
         Stored as ``self._use_residual``; not read by this constructor.
     output_attention : bool
         Stored as ``self._output_attention``; not read by this constructor.
     i2h_weight_initializer, h2h_weight_initializer,
     i2h_bias_initializer, h2h_bias_initializer
         Forwarded unchanged to every recurrent cell.
     prefix, params
         Forwarded to the base-class constructor.
     """
     super(GNMTDecoder, self).__init__(prefix=prefix, params=params)
     self._cell_type = _get_cell_type(cell_type)
     self._num_layers, self._hidden_size = num_layers, hidden_size
     self._dropout = dropout
     self._use_residual, self._output_attention = use_residual, output_attention

     # Every layer shares these settings; only its prefix differs.
     shared_cell_args = dict(hidden_size=self._hidden_size,
                             i2h_weight_initializer=i2h_weight_initializer,
                             h2h_weight_initializer=h2h_weight_initializer,
                             i2h_bias_initializer=i2h_bias_initializer,
                             h2h_bias_initializer=h2h_bias_initializer)

     with self.name_scope():
         # NOTE(review): children are created in a fixed order (attention,
         # dropout, stack) in case auto-generated prefixes depend on it.
         self.attention_cell = _get_attention_cell(attention_cell, units=hidden_size)
         self.dropout_layer = nn.Dropout(dropout)
         self.rnn_cells = nn.HybridSequential()
         for layer_idx in range(num_layers):
             layer_prefix = 'rnn%d_' % layer_idx
             self.rnn_cells.add(self._cell_type(prefix=layer_prefix, **shared_cell_args))
コード例 #2
0
ファイル: gnmt.py プロジェクト: tayyabjsr13/gluon-nlp
 def __init__(self,
              cell_type='lstm',
              num_layers=2,
              num_bi_layers=1,
              hidden_size=128,
              dropout=0.0,
              use_residual=True,
              i2h_weight_initializer=None,
              h2h_weight_initializer=None,
              i2h_bias_initializer='zeros',
              h2h_bias_initializer='zeros',
              prefix=None,
              params=None):
     """Create the encoder's dropout layer and stack of recurrent cells.

     The first ``num_bi_layers`` layers are ``rnn.BidirectionalCell`` wrappers
     around a left/right pair of cells; the remaining layers are single cells.

     Parameters
     ----------
     cell_type : str or type
         Resolved through ``_get_cell_type``; the result is called to build
         every cell (presumably a name such as ``'lstm'`` maps to a cell
         class -- confirm in ``_get_cell_type``).
     num_layers : int
         Total number of layers added to ``self.rnn_cells``.
     num_bi_layers : int
         Number of leading layers that are bidirectional. Must not exceed
         ``num_layers``; equality is allowed.
     hidden_size : int
         Hidden size of every cell.
     dropout : float
         Passed to ``nn.Dropout`` to build ``self.dropout_layer``.
     use_residual : bool
         Stored as ``self._use_residual``; not read by this constructor.
     i2h_weight_initializer, h2h_weight_initializer,
     i2h_bias_initializer, h2h_bias_initializer
         Forwarded unchanged to every cell.
     prefix, params
         Forwarded to the base-class constructor.

     Raises
     ------
     AssertionError
         If ``num_bi_layers > num_layers`` (only when assertions are enabled).
     """
     super(GNMTEncoder, self).__init__(prefix=prefix, params=params)
     self._cell_type = _get_cell_type(cell_type)
     # The check uses ``<=``, so the message says "must not exceed" rather than
     # "must be smaller than". NOTE(review): ``assert`` is stripped under
     # ``python -O``; it is kept here so callers still see AssertionError.
     assert num_bi_layers <= num_layers, \
         'Number of bidirectional layers must not exceed the total number of layers, ' \
         'num_bi_layers={}, num_layers={}'.format(num_bi_layers, num_layers)
     self._num_bi_layers = num_bi_layers
     self._num_layers = num_layers
     self._hidden_size = hidden_size
     self._dropout = dropout
     self._use_residual = use_residual

     def make_cell(cell_prefix):
         """Build one unidirectional cell with the encoder's shared settings."""
         return self._cell_type(
             hidden_size=self._hidden_size,
             i2h_weight_initializer=i2h_weight_initializer,
             h2h_weight_initializer=h2h_weight_initializer,
             i2h_bias_initializer=i2h_bias_initializer,
             h2h_bias_initializer=h2h_bias_initializer,
             prefix=cell_prefix)

     with self.name_scope():
         self.dropout_layer = nn.Dropout(dropout)
         self.rnn_cells = nn.HybridSequential()
         for i in range(num_layers):
             if i < num_bi_layers:
                 # Keyword arguments are evaluated left to right, so the left
                 # cell, then the right cell, then the wrapper are created.
                 layer = rnn.BidirectionalCell(
                     l_cell=make_cell('rnn%d_l_' % i),
                     r_cell=make_cell('rnn%d_r_' % i))
             else:
                 layer = make_cell('rnn%d_' % i)
             self.rnn_cells.add(layer)