class RNN(Base):
    """RNN encoder: runs an embedded sequence through a (possibly stacked /
    bidirectional) RNN and projects the result to `num_output` units with a
    shared dense layer.

    Sequence lengths come from a `tf.placeholder` named `<name>_length`, or
    from the `features` dict when one is supplied. Initial RNN states can be
    injected at run time through `feed_dict` / `pb_feed_dict`.
    """

    def __init__(self, **args):
        """Configure the layer from keyword args.

        Required keys: maxlen, keep_prob, batch_size, rnn_type, num_output.
        Optional keys: num_hidden (default 256), num_layers (default 2).
        """
        self.maxlen = args['maxlen']
        self.num_hidden = args.get('num_hidden', 256)
        self.num_layers = args.get('num_layers', 2)
        self.keep_prob = args['keep_prob']
        self.batch_size = args['batch_size']
        self.rnn_type = args['rnn_type']
        self.num_output = args['num_output']
        self.rnn_layer = RNNLayer(self.rnn_type, self.num_hidden,
                                  self.num_layers)
        # name -> length placeholder (or the feature tensor replacing it)
        self.placeholder = {}

    def __call__(self, embed, name='encoder', middle_flag=False,
                 hidden_flag=False, features=None, reuse=tf.AUTO_REUSE,
                 **kwargs):
        """Build the RNN graph over `embed` and return the projection.

        middle_flag: if True, project every time step and return a tensor of
            shape [batch_size, max_time, num_output]; otherwise flatten all
            time steps and return [batch_size, num_output].
        hidden_flag: if True, additionally return the final state and the
            layer's pb node names.
        features: optional dict supplying the `<name>_length` tensor; when
            None a fresh placeholder is created instead.
        """
        length_name = name + "_length"
        if features is None:
            self.placeholder[length_name] = tf.placeholder(
                dtype=tf.int32, shape=[None], name=length_name)
        else:
            self.placeholder[length_name] = features[length_name]
        self.initial_state = None
        with tf.variable_scope("rnn", reuse=reuse):
            # gru/lstm outputs: [batch_size, max_time, num_hidden]
            # bi_gru/bi_lstm outputs: [batch_size, max_time, num_hidden*2]
            outputs, _, state = self.rnn_layer(
                inputs=embed,
                seq_len=self.placeholder[length_name])
            outputs_shape = outputs.shape.as_list()
            if middle_flag:
                # Per-time-step projection: flatten to 2-D for the dense
                # layer, then restore the time axis.
                outputs = tf.reshape(outputs, [-1, outputs_shape[2]])
                dense = tf.layers.dense(outputs, self.num_output, name='fc')
                # [batch_size, max_time, num_output]
                dense = tf.reshape(
                    dense, [-1, outputs_shape[1], self.num_output])
            else:
                # Concatenate all time steps into one feature vector per
                # example before projecting.
                outputs = tf.reshape(
                    outputs, [-1, outputs_shape[1] * outputs_shape[2]])
                # [batch_size, num_output]
                dense = tf.layers.dense(outputs, self.num_output, name='fc')
            if hidden_flag:
                return dense, state, self.rnn_layer.pb_nodes
            return dense

    def feed_dict(self, name='encoder', initial_state=None, **kwargs):
        """Build a feed dict for the `<name>_length` placeholder.

        NOTE(review): every kwarg value is written to the same length
        placeholder, so with multiple kwargs the last one wins — presumably
        callers pass exactly one; confirm against call sites.
        """
        feed_dict = {}
        length_name = name + "_length"
        for key in kwargs:
            feed_dict[self.placeholder[length_name]] = kwargs[key]
        # Feed initial-state values when provided. `is not None` (rather than
        # `!= None`) avoids the ambiguous elementwise comparison that
        # array-like states would trigger.
        if initial_state is not None:
            feed_dict.update(self.rnn_layer.feed_dict(initial_state))
        return feed_dict

    def pb_feed_dict(self, graph, name='encoder', initial_state=None,
                     **kwargs):
        """Same as `feed_dict`, but resolves the length tensor by operation
        name from a frozen (pb) graph instead of using stored placeholders."""
        feed_dict = {}
        length_name = name + "_length"
        for key in kwargs:
            key_node = graph.get_operation_by_name(length_name).outputs[0]
            feed_dict[key_node] = kwargs[key]
        # Feed initial-state values when provided.
        if initial_state is not None:
            feed_dict.update(self.rnn_layer.feed_dict(initial_state, graph))
        return feed_dict
class Seq2seq(EncoderBase):
    """Sequence-to-sequence model: an encoder RNN whose final state seeds a
    decoder RNN, followed by a per-time-step dense projection.

    Two int32 length placeholders are created per call, named
    `<name>_encode_length` and `<name>_decode_length`.
    """

    def __init__(self, **kwargs):
        """Configure both RNN stacks.

        Required key: rnn_type.
        Optional keys: num_hidden (default 256), num_layers (default 2).
        Remaining kwargs are forwarded to EncoderBase.
        """
        super(Seq2seq, self).__init__(**kwargs)
        self.num_hidden = kwargs.get('num_hidden', 256)
        self.num_layers = kwargs.get('num_layers', 2)
        self.rnn_type = kwargs['rnn_type']
        self.rnn_encode_layer = RNNLayer(self.rnn_type, self.num_hidden,
                                         self.num_layers)
        self.rnn_decode_layer = RNNLayer(self.rnn_type, self.num_hidden,
                                         self.num_layers)
        # name -> length placeholder (or the feature tensor replacing it)
        self.placeholder = {}

    def __call__(self, net_encode, net_decode, name='seq2seq', features=None,
                 reuse=tf.AUTO_REUSE, **kwargs):
        """Build the encoder-decoder graph.

        Returns (dense, encode_feed_state, decode_feed_state, pb_nodes),
        where dense has shape [batch_size, max_time, num_output].

        NOTE(review): `self.num_output` is not set here — it is assumed to be
        provided by EncoderBase.__init__; confirm.
        """
        length_encode_name = name + "_encode_length"
        length_decode_name = name + "_decode_length"
        self.placeholder[length_encode_name] = tf.placeholder(
            dtype=tf.int32, shape=[None], name=length_encode_name)
        self.placeholder[length_decode_name] = tf.placeholder(
            dtype=tf.int32, shape=[None], name=length_decode_name)
        if features is not None:
            # Keep the raw placeholders around, then switch to the tensors
            # supplied by the input pipeline.
            self.features = copy.copy(self.placeholder)
            self.placeholder[length_encode_name] = features[length_encode_name]
            self.placeholder[length_decode_name] = features[length_decode_name]
        outputs, final_state_encode, final_state_encode_for_feed = \
            self.rnn_encode_layer(
                inputs=net_encode,
                seq_len=self.placeholder[length_encode_name],
                name='encode')
        # TODO: fix the issue that the decoder depends on the encoder and
        # therefore cannot run standalone prediction.
        outputs, final_state_decode, final_state_decode_for_feed = \
            self.rnn_decode_layer(
                inputs=net_decode,
                seq_len=self.placeholder[length_decode_name],
                initial_state=final_state_encode,
                name='decode')
        outputs_shape = outputs.shape.as_list()
        # Flatten to 2-D for the dense layer, then restore the time axis.
        outputs = tf.reshape(outputs, [-1, outputs_shape[2]])
        dense = tf.layers.dense(outputs, self.num_output, name='fc')
        # [batch_size, max_time, num_output]
        dense = tf.reshape(dense, [-1, outputs_shape[1], self.num_output])
        return dense, final_state_encode_for_feed, \
            final_state_decode_for_feed, self.rnn_decode_layer.pb_nodes

    def feed_dict(self, name='seq2seq', initial_state=None, **kwargs):
        """Build a feed dict from kwargs whose values are
        (encode_lengths, decode_lengths) pairs; either element may be None to
        skip that placeholder.

        `is not None` (rather than `!= None`) is required here: length values
        are typically array-like, and `!= None` would compare elementwise and
        make the `if` ambiguous.
        """
        feed_dict = {}
        length_encode_name = name + "_encode_length"
        length_decode_name = name + "_decode_length"
        for key in kwargs:
            if kwargs[key][0] is not None:
                feed_dict[self.placeholder[length_encode_name]] = \
                    kwargs[key][0]
            if kwargs[key][1] is not None:
                feed_dict[self.placeholder[length_decode_name]] = \
                    kwargs[key][1]
        # Feed decoder initial-state values when provided.
        if initial_state is not None:
            feed_dict.update(self.rnn_decode_layer.feed_dict(initial_state))
        return feed_dict

    def pb_feed_dict(self, graph, name='seq2seq', initial_state=None,
                     **kwargs):
        """Same as `feed_dict`, but resolves the length tensors by operation
        name from a frozen (pb) graph."""
        feed_dict = {}
        length_encode_name = name + "_encode_length"
        length_decode_name = name + "_decode_length"
        for key in kwargs:
            key_node0 = \
                graph.get_operation_by_name(length_encode_name).outputs[0]
            key_node1 = \
                graph.get_operation_by_name(length_decode_name).outputs[0]
            if kwargs[key][0] is not None:
                feed_dict[key_node0] = kwargs[key][0]
            if kwargs[key][1] is not None:
                feed_dict[key_node1] = kwargs[key][1]
        # Feed decoder initial-state values when provided.
        if initial_state is not None:
            feed_dict.update(
                self.rnn_decode_layer.feed_dict(initial_state, graph))
        return feed_dict