Example #1
0
class RNN(Base):
    """RNN encoder: runs an RNNLayer over embedded input and projects the
    outputs to ``num_output`` units with a shared dense layer ('fc').

    Supports gru/lstm and their bidirectional variants via ``rnn_type``
    (the concrete cell wiring lives in RNNLayer).
    """

    def __init__(self, **args):
        # Required hyper-parameters.
        self.maxlen = args['maxlen']
        self.keep_prob = args['keep_prob']
        self.batch_size = args['batch_size']
        self.rnn_type = args['rnn_type']
        self.num_output = args['num_output']
        # Optional hyper-parameters with defaults.
        self.num_hidden = args.get('num_hidden', 256)
        self.num_layers = args.get('num_layers', 2)
        self.rnn_layer = RNNLayer(self.rnn_type, self.num_hidden,
                                  self.num_layers)
        # Maps placeholder name -> placeholder (or feature tensor).
        self.placeholder = {}

    def __call__(self,
                 embed,
                 name='encoder',
                 middle_flag=False,
                 hidden_flag=False,
                 features=None,
                 reuse=tf.AUTO_REUSE,
                 **kwargs):
        """Build the graph and return the dense projection.

        Args:
            embed: embedded input sequence; presumably
                [batch_size, max_time, embed_dim] -- TODO confirm with caller.
            name: prefix for the sequence-length placeholder ("<name>_length").
            middle_flag: if True, return per-time-step outputs of shape
                [batch_size, max_time, num_output]; otherwise flatten the
                whole sequence and return [batch_size, num_output].
            hidden_flag: if True, also return the final RNN state and the
                layer's pb nodes.
            features: optional dict supplying the length tensor directly
                (e.g. from an input pipeline); when None a placeholder is
                created instead.
            reuse: variable-scope reuse flag.

        Returns:
            dense, or (dense, state, pb_nodes) when hidden_flag is True.
        """
        length_name = name + "_length"
        # `is None`, not `== None`: avoids overloaded equality on tensors.
        if features is None:
            self.placeholder[length_name] = tf.placeholder(dtype=tf.int32,
                                                           shape=[None],
                                                           name=length_name)
        else:
            self.placeholder[length_name] = features[length_name]

        self.initial_state = None
        with tf.variable_scope("rnn", reuse=reuse):
            # for gru,lstm outputs: [batch_size, max_time, num_hidden]
            # for bi_gru,bi_lstm outputs: [batch_size, max_time, num_hidden*2]
            outputs, _, state = self.rnn_layer(
                inputs=embed, seq_len=self.placeholder[length_name])
            outputs_shape = outputs.shape.as_list()
            if middle_flag:
                # Project every time step: flatten, dense, restore shape.
                outputs = tf.reshape(outputs, [-1, outputs_shape[2]])
                dense = tf.layers.dense(outputs, self.num_output, name='fc')
                # [batch_size, max_time, num_output]
                dense = tf.reshape(dense,
                                   [-1, outputs_shape[1], self.num_output])
            else:
                # Flatten the whole sequence and project once.
                outputs = tf.reshape(outputs,
                                     [-1, outputs_shape[1] * outputs_shape[2]])
                # [batch_size, num_output]
                dense = tf.layers.dense(outputs, self.num_output, name='fc')
            # Alternative: use only the last time step's output.
            # outputs = outputs[:, -1, :]
            if hidden_flag:
                return dense, state, self.rnn_layer.pb_nodes
            return dense

    def feed_dict(self, name='encoder', initial_state=None, **kwargs):
        """Build a feed dict for the length placeholder (and initial state)."""
        feed_dict = {}
        length_name = name + "_length"  # loop-invariant; hoisted.
        # NOTE(review): every kwarg maps to the SAME placeholder, so with
        # multiple kwargs the last one wins -- callers appear to pass one.
        for key in kwargs:
            feed_dict[self.placeholder[length_name]] = kwargs[key]
        # Feed initial-state values, if provided.
        # `is not None`: a state tuple/array must not go through `!=`.
        if initial_state is not None:
            feed_dict.update(self.rnn_layer.feed_dict(initial_state))
        return feed_dict

    def pb_feed_dict(self,
                     graph,
                     name='encoder',
                     initial_state=None,
                     **kwargs):
        """Like feed_dict, but resolves placeholders from a frozen graph."""
        feed_dict = {}
        length_name = name + "_length"  # loop-invariant; hoisted.
        for key in kwargs:
            key_node = graph.get_operation_by_name(length_name).outputs[0]
            feed_dict[key_node] = kwargs[key]
        # Feed initial-state values, if provided.
        if initial_state is not None:
            feed_dict.update(self.rnn_layer.feed_dict(initial_state, graph))
        return feed_dict
Example #2
0
class Seq2seq(EncoderBase):
    """Sequence-to-sequence model: an encoder RNN whose final state seeds a
    decoder RNN, followed by a per-time-step dense projection ('fc')."""

    def __init__(self, **kwargs):
        super(Seq2seq, self).__init__(**kwargs)
        # Optional hyper-parameters with defaults.
        self.num_hidden = kwargs.get('num_hidden', 256)
        self.num_layers = kwargs.get('num_layers', 2)
        # Required hyper-parameter.
        self.rnn_type = kwargs['rnn_type']
        self.rnn_encode_layer = RNNLayer(self.rnn_type, self.num_hidden,
                                         self.num_layers)
        self.rnn_decode_layer = RNNLayer(self.rnn_type, self.num_hidden,
                                         self.num_layers)
        # Maps placeholder name -> placeholder (or feature tensor).
        self.placeholder = {}

    def __call__(self, net_encode, net_decode, name='seq2seq',
                 features=None, reuse=tf.AUTO_REUSE, **kwargs):
        """Build encoder + decoder and return the decoder projection.

        Args:
            net_encode: embedded encoder input; presumably
                [batch_size, max_time, embed_dim] -- TODO confirm.
            net_decode: embedded decoder input, same layout assumption.
            name: prefix for the two length placeholders.
            features: optional dict supplying the length tensors directly;
                when given it overrides the freshly created placeholders.
            reuse: variable-scope reuse flag (unused here; RNNLayer scopes
                are set by its `name` argument).

        Returns:
            (dense, encoder_feed_state, decoder_feed_state, pb_nodes) where
            dense has shape [batch_size, max_time, num_output].
        """
        length_encode_name = name + "_encode_length"
        length_decode_name = name + "_decode_length"
        # Placeholders are always created so the graph nodes exist for
        # pb_feed_dict lookups, then overridden when features are supplied.
        self.placeholder[length_encode_name] = tf.placeholder(
            dtype=tf.int32, shape=[None], name=length_encode_name)
        self.placeholder[length_decode_name] = tf.placeholder(
            dtype=tf.int32, shape=[None], name=length_decode_name)
        # `is not None`, not `!= None`: dicts/tensors may overload equality.
        if features is not None:
            # Keep the raw placeholders around before overriding them.
            self.features = copy.copy(self.placeholder)
            self.placeholder[length_encode_name] = features[length_encode_name]
            self.placeholder[length_decode_name] = features[length_decode_name]

        outputs, final_state_encode, final_state_encode_for_feed = \
            self.rnn_encode_layer(
                inputs=net_encode,
                seq_len=self.placeholder[length_encode_name],
                name='encode')
        # TODO: fix the decoder's dependence on the encoder so single-step
        # prediction becomes possible.
        outputs, final_state_decode, final_state_decode_for_feed = \
            self.rnn_decode_layer(
                inputs=net_decode,
                seq_len=self.placeholder[length_decode_name],
                initial_state=final_state_encode,
                name='decode')
        # Project every decoder time step: flatten, dense, restore shape.
        outputs_shape = outputs.shape.as_list()
        outputs = tf.reshape(outputs, [-1, outputs_shape[2]])
        dense = tf.layers.dense(outputs, self.num_output, name='fc')
        # [batch_size, max_time, num_output]
        dense = tf.reshape(dense, [-1, outputs_shape[1], self.num_output])

        return dense, final_state_encode_for_feed, \
            final_state_decode_for_feed, self.rnn_decode_layer.pb_nodes

    def feed_dict(self, name='seq2seq', initial_state=None, **kwargs):
        """Build a feed dict from (encode_len, decode_len) kwarg pairs.

        Each kwarg value is a 2-tuple; element 0 feeds the encoder length,
        element 1 the decoder length; either may be None to skip it.
        """
        feed_dict = {}
        # Loop-invariant names hoisted out of the loop.
        length_encode_name = name + "_encode_length"
        length_decode_name = name + "_decode_length"
        for key in kwargs:
            # `is not None`: `!= None` on a numpy array yields an elementwise
            # bool array whose truthiness raises ValueError.
            if kwargs[key][0] is not None:
                feed_dict[self.placeholder[length_encode_name]] = kwargs[key][0]
            if kwargs[key][1] is not None:
                feed_dict[self.placeholder[length_decode_name]] = kwargs[key][1]
        # Feed decoder initial-state values, if provided.
        if initial_state is not None:
            feed_dict.update(self.rnn_decode_layer.feed_dict(initial_state))
        return feed_dict

    def pb_feed_dict(self, graph, name='seq2seq', initial_state=None, **kwargs):
        """Like feed_dict, but resolves placeholders from a frozen graph."""
        feed_dict = {}
        # Loop-invariant names and nodes hoisted out of the loop.
        length_encode_name = name + "_encode_length"
        length_decode_name = name + "_decode_length"
        for key in kwargs:
            key_node0 = graph.get_operation_by_name(
                length_encode_name).outputs[0]
            key_node1 = graph.get_operation_by_name(
                length_decode_name).outputs[0]
            if kwargs[key][0] is not None:
                feed_dict[key_node0] = kwargs[key][0]
            if kwargs[key][1] is not None:
                feed_dict[key_node1] = kwargs[key][1]
        # Feed decoder initial-state values, if provided.
        if initial_state is not None:
            feed_dict.update(
                self.rnn_decode_layer.feed_dict(initial_state, graph))
        return feed_dict