import tensorflow as tf
import tensorflow_addons as tfa

from opennmt import decoders, encoders, inputters, layers
from opennmt.layers import bridge


def __init__(self):
    # Medium-sized LSTM sequence-to-sequence model with Luong attention:
    # 2-layer, 512-unit encoder and decoder, 512-dim word embeddings.
    super().__init__(
        source_inputter=inputters.WordEmbedder(embedding_size=512),
        target_inputter=inputters.WordEmbedder(embedding_size=512),
        encoder=encoders.RNNEncoder(
            num_layers=2,
            num_units=512,
            residual_connections=False,
            dropout=0.3,
            cell_class=tf.keras.layers.LSTMCell,
        ),
        decoder=decoders.AttentionalRNNDecoder(
            num_layers=2,
            num_units=512,
            bridge_class=layers.CopyBridge,
            attention_mechanism_class=tfa.seq2seq.LuongAttention,
            attention_layer_activation=None,
            cell_class=tf.keras.layers.LSTMCell,
            dropout=0.3,
            residual_connections=False,
        ),
    )
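# The decoder above wires in tfa.seq2seq.LuongAttention. As a reference, here
# is a minimal standalone sketch of Luong-style (multiplicative) attention in
# plain TensorFlow; the function and variable names are illustrative
# assumptions, not library API.
def _luong_attention_sketch(query, memory):
    """Returns (context, alignments) for one decoding step.

    query:  [batch, units] decoder state for the current step.
    memory: [batch, src_len, units] encoder outputs.
    """
    scores = tf.einsum("bu,btu->bt", query, memory)  # dot-product scores
    alignments = tf.nn.softmax(scores, axis=-1)      # attention weights over source
    context = tf.einsum("bt,btu->bu", alignments, memory)
    return context, alignments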
def testAttentionalRNNDecoderFirstLayer(self):
    # first_layer_attention=True outputs attention after the first RNN layer
    # instead of the last one.
    decoder = decoders.AttentionalRNNDecoder(2, 20, first_layer_attention=True)
    self._testDecoder(decoder)
def testAttentionalRNNDecoder(self):
    decoder = decoders.AttentionalRNNDecoder(2, 20)
    self._testDecoder(decoder)
def testAttentionalRNNDecoderTraining(self):
    decoder = decoders.AttentionalRNNDecoder(2, 20)
    self._testDecoderTraining(decoder, support_alignment_history=True)
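# "Alignment history" in the test above refers to the per-step attention
# weights stacked over decoding time. A toy illustration of that structure;
# the shapes and names are assumptions, not what _testDecoderTraining builds.
def _alignment_history_sketch(batch=2, steps=4, src_len=7, units=20):
    memory = tf.random.normal([batch, src_len, units])   # fake encoder outputs
    history = []
    for _ in range(steps):
        query = tf.random.normal([batch, units])         # fake decoder state
        scores = tf.einsum("bu,btu->bt", query, memory)
        history.append(tf.nn.softmax(scores, axis=-1))   # one step's alignments
    return tf.stack(history, axis=1)                     # [batch, steps, src_len]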
def testAttentionalRNNDecoderWithDenseBridge(self):
    decoder = decoders.AttentionalRNNDecoder(2, 36, bridge_class=bridge.DenseBridge)
    # Build a small 2-layer encoder cell whose final state the bridge must
    # project into the decoder's 36-unit initial state. Uses the TF2 Keras
    # cell API rather than the removed tf.nn.rnn_cell module.
    encoder_cell = tf.keras.layers.StackedRNNCells(
        [tf.keras.layers.LSTMCell(5), tf.keras.layers.LSTMCell(5)]
    )
    initial_state_fn = lambda batch_size, dtype: encoder_cell.get_initial_state(
        batch_size=batch_size, dtype=dtype
    )
    self._testDecoder(decoder, initial_state_fn=initial_state_fn)
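# A dense bridge resolves the size mismatch exercised above (5-unit encoder
# state vs. 36-unit decoder state) by projecting through a dense layer. A toy
# version of the idea; the class name and structure are assumptions, not
# OpenNMT-tf's actual DenseBridge implementation.
class _ToyDenseBridge(tf.keras.layers.Layer):
    def __init__(self, decoder_units):
        super().__init__()
        self.proj = tf.keras.layers.Dense(decoder_units)

    def call(self, encoder_state):
        # Flatten and concatenate all encoder state tensors, then project to
        # the size the decoder expects for its initial state.
        flat = tf.concat(tf.nest.flatten(encoder_state), axis=-1)
        return self.proj(flat)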