-
Notifications
You must be signed in to change notification settings - Fork 0
/
bilstm.py
32 lines (26 loc) · 1019 Bytes
/
bilstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# coding:utf-8
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell
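## Define the bidirectional RNN model (GRU cells, despite the BILSTM name) and return related features
# Returns one output per unrolled time step (n steps -> n outputs), stacked into a single tensor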
def BILSTM(x, hidden_size, scope=None):
    """Apply a bidirectional recurrent layer built from GRU cells.

    Despite the name, both directions use ``rnn_cell.GRUCell``, not LSTM
    cells; the name is kept so existing callers keep working.

    The input is split along its time axis into a Python list of per-step
    tensors and run through ``rnn.bidirectional_rnn`` (the statically
    unrolled TensorFlow 0.x API). The per-step outputs are then stacked
    back into one batch-major tensor.

    Args:
        x: 3-D tensor. Use ``[batch, height, width]`` for images or
            ``[batch, step, embedding_size]`` for sequences. Axis 1 is
            treated as time. Its length presumably must be statically known,
            since ``tf.unpack`` has to decide how many tensors to produce.
            TODO confirm for the TF version in use.
        hidden_size: number of hidden units in each direction's GRU cell.
        scope: optional variable scope forwarded to
            ``rnn.bidirectional_rnn``. The default ``None`` keeps the
            original behavior (the library's default scope). Pass distinct
            scopes when building more than one such layer in the same graph,
            so their variables do not share names.

    Returns:
        Tensor of shape ``[batch, height, 2*hidden_size]`` or
        ``[batch, step, 2*hidden_size]``. For every step it holds the forward
        and backward outputs, depth-concatenated by ``bidirectional_rnn``.
    """
    # NOTE(review): tf.unpack / tf.pack / rnn.bidirectional_rnn are
    # TensorFlow 0.x APIs; porting to TF >= 1.0 requires renaming them.

    # Time-major view: swap axes 0 and 1 -> [step, batch, embedding_size].
    time_major = tf.transpose(x, [1, 0, 2])
    # Static unrolling: list of `step` tensors, each [batch, embedding_size].
    step_inputs = tf.unpack(time_major)

    # One GRU cell per direction, each with `hidden_size` units.
    fw_cell = rnn_cell.GRUCell(hidden_size)
    bw_cell = rnn_cell.GRUCell(hidden_size)
    # Only the per-step outputs are used; the final fw/bw states are dropped.
    step_outputs, _, _ = rnn.bidirectional_rnn(
        fw_cell, bw_cell, step_inputs, dtype=tf.float32, scope=scope)

    # Re-assemble the list of [batch, 2*hidden_size] tensors into
    # [step, batch, 2*hidden_size], then return to batch-major layout.
    stacked = tf.pack(step_outputs)
    return tf.transpose(stacked, [1, 0, 2])