def testAttentionalRNNDecoderWithDenseBridge(self):
    """Exercise an AttentionalRNNDecoder wired through a DenseBridge.

    The initial state passed to the decoder harness is the zero state of a
    two-layer MultiRNNCell built from 5-unit LSTM cells. The actual checks
    are delegated to ``self._testDecoder`` (defined elsewhere in the class).
    """
    attention_decoder = decoders.AttentionalRNNDecoder(
        2, 36, bridge=bridge.DenseBridge())

    lstm_layers = [tf.nn.rnn_cell.LSTMCell(5), tf.nn.rnn_cell.LSTMCell(5)]
    stacked_encoder = tf.nn.rnn_cell.MultiRNNCell(lstm_layers)

    # Named function instead of an assigned lambda (PEP 8 E731). The parameter
    # names are kept as batch_size/dtype so keyword-style calls from the
    # harness still bind correctly.
    def zero_encoder_state(batch_size, dtype):
        return stacked_encoder.zero_state(batch_size, dtype)

    self._testDecoder(attention_decoder, initial_state_fn=zero_encoder_state)
def testDenseBridge(self):
    """Check that DenseBridge produces a state compatible with the decoder's.

    The encoder and decoder states come from ``_build_state`` with different
    leading arguments (presumably different cell counts/sizes; confirm
    against the helper). The bridged result is validated with
    ``bridge.assert_state_is_compatible`` against the decoder state.
    """
    source_state = _build_state(3, 20, 6)
    target_state = _build_state(4, 30, 6)

    dense_bridge = bridge.DenseBridge()
    bridged_state = dense_bridge(source_state, target_state)

    bridge.assert_state_is_compatible(target_state, bridged_state)