示例#1
0
文件: RTHN.py 项目: Renren-yu/RTHN
def senEncode_softmax(s_senEncode, w_varible, b_varible, n_feature, doc_len):
    """Map clause encodings to per-clause class probabilities.

    A single linear layer followed by a softmax. Dropout uses
    ``FLAGS.keep_prob2`` directly rather than a keep-probability argument.

    Args:
        s_senEncode: clause encodings; flattened here to (-1, n_feature).
        w_varible: name handed to ``func.get_weight_varible`` for the weight.
        b_varible: name handed to ``func.get_weight_varible`` for the bias.
        n_feature: size of the last dimension of ``s_senEncode``.
        doc_len: document lengths, forwarded to ``func.getmask``.

    Returns:
        (pred, reg): ``pred`` reshaped to
        (-1, FLAGS.max_doc_len, FLAGS.n_class); ``reg`` is the sum of
        ``tf.nn.l2_loss`` over the weight and the bias.
    """
    flat = tf.nn.dropout(tf.reshape(s_senEncode, [-1, n_feature]),
                         keep_prob=FLAGS.keep_prob2)
    weight = func.get_weight_varible(w_varible, [n_feature, FLAGS.n_class])
    bias = func.get_weight_varible(b_varible, [FLAGS.n_class])
    # NOTE(review): the mask scales the *logits*, before the softmax.
    # Assuming func.getmask yields 0 for padded clauses, those rows become
    # all-zero logits and therefore a uniform distribution, not zeros.
    logits = (tf.matmul(flat, weight) + bias) * func.getmask(
        doc_len, FLAGS.max_doc_len, [-1, 1])
    probs = tf.reshape(tf.nn.softmax(logits),
                       [-1, FLAGS.max_doc_len, FLAGS.n_class])
    return probs, tf.nn.l2_loss(weight) + tf.nn.l2_loss(bias)
示例#2
0
文件: ecjc.py 项目: LeMei/ecjd
def build_model(x,
                sen_len,
                doc_len,
                word_embedding,
                clause_position,
                embedding_pos,
                keep_prob1,
                keep_prob2,
                RNN=func.biLSTM):
    """Build the joint emotion/cause clause graph.

    Pipeline:
      1. Words are embedded, encoded with ``RNN`` and pooled into clause
         encodings with word-level attention (``func.att_var``).
      2. Assist layer 1 predicts emotion clauses. The predicted emotion
         position gives every clause a relative-position embedding, which
         conditions the cause prediction of layer 1.
      3. Assist layers 2 .. ``FLAGS.assist_n_layers - 1`` refine the emotion
         and cause encodings. Each layer is fed the label matrices predicted
         by earlier layers.
      4. A main network fuses both encodings and produces the final
         prediction via ``senEncode_main_softmax``.

    Args:
        x: word ids; looked up in ``word_embedding`` and reshaped to
            (-1, FLAGS.max_sen_len, FLAGS.embedding_dim).
        sen_len: clause lengths in words; flattened to 1-D.
        doc_len: document lengths in clauses; forwarded to the softmax
            helpers and ``func.getmask``.
        word_embedding: word embedding matrix.
        clause_position: per-clause positions. They must be int32 because
            they are subtracted from an int32 tensor. Presumably 1-based
            (TODO confirm against the data loader).
        embedding_pos: relative-position embedding matrix.
        keep_prob1: dropout keep probability for the word embeddings.
        keep_prob2: accepted for interface compatibility; not used in this
            function body.
        RNN: word-level encoder, called as
            ``RNN(inputs, sen_len, n_hidden=..., scope=...)``.

    Returns:
        (pred, reg, pred_emotion_assist_list, reg_emotion_assist_list,
         pred_cause_assist_list, reg_cause_assist_list): the main prediction
        and its L2 term, plus one prediction and one L2 term per assist layer.

    Raises:
        ValueError: if ``FLAGS.assist_n_layers <= 1``. The emotion/cause
            encodings consumed by the main network are only built by the
            assist layers, so the graph cannot be constructed without them.
    """
    # Fail fast, before any graph ops are created. Without this guard the
    # main network would hit an UnboundLocalError.
    if FLAGS.assist_n_layers <= 1:
        raise ValueError(
            'build_model requires FLAGS.assist_n_layers > 1, got %r' %
            (FLAGS.assist_n_layers,))

    x = tf.nn.embedding_lookup(word_embedding, x)
    inputs = tf.reshape(x, [-1, FLAGS.max_sen_len, FLAGS.embedding_dim])
    # Width of the word-encoder output as reshaped below
    # (RNN itself is called with FLAGS.n_hidden).
    n_hidden = 2 * FLAGS.n_hidden

    inputs = tf.nn.dropout(inputs, keep_prob=keep_prob1)
    sen_len = tf.reshape(sen_len, [-1])
    with tf.name_scope('word_encode'):
        wordEncode = RNN(inputs,
                         sen_len,
                         n_hidden=FLAGS.n_hidden,
                         scope=FLAGS.scope + 'word_layer')
    wordEncode = tf.reshape(wordEncode, [-1, FLAGS.max_sen_len, n_hidden])

    with tf.name_scope('attention'):
        w1 = func.get_weight_varible('word_att_w1', [n_hidden, n_hidden])
        b1 = func.get_weight_varible('word_att_b1', [n_hidden])
        w2 = func.get_weight_varible('word_att_w2', [n_hidden, 1])
        senEncode = func.att_var(wordEncode, sen_len, w1, b1, w2)
    # Clause encodings: (batch, max_doc_len, n_hidden).
    senEncode = tf.reshape(senEncode, [-1, FLAGS.max_doc_len, n_hidden])

    n_feature = 2 * FLAGS.n_hidden
    out_units = 2 * FLAGS.n_hidden
    batch = tf.shape(senEncode)[0]
    pred_zeros = tf.zeros(([batch, FLAGS.max_doc_len, FLAGS.max_doc_len]))
    # All-ones matrix with a zero diagonal (NOT an identity matrix), tiled
    # over the batch to (batch, max_doc_len, max_doc_len). Multiplying a
    # broadcast label row by it hides each clause's own label from itself.
    matrix = tf.reshape(
        (1 - tf.eye(FLAGS.max_doc_len)),
        [1, FLAGS.max_doc_len, FLAGS.max_doc_len]) + pred_zeros
    pred_emotion_assist_list, reg_emotion_assist_list, pred_emotion_assist_label_list = [], [], []
    pred_cause_assist_list, reg_cause_assist_list, pred_cause_assist_label_list = [], [], []

    # ---- Assist layer 1: emotion ----
    # trans_func is assumed to return (batch, max_doc_len, out_units).
    emotion_senEncode = trans_func(senEncode, senEncode, n_feature,
                                   out_units, 'emotion_layer1')
    pred_emotion_assist, reg_emotion_assist = senEncode_emotion_softmax(
        emotion_senEncode, 'softmax_assist_w1', 'softmax_assist_b1',
        out_units, doc_len)
    # Hard predicted class per clause, laid out as one row:
    # (batch, 1, max_doc_len).
    pred_emotion_assist_label = tf.cast(
        tf.reshape(tf.argmax(pred_emotion_assist, axis=2),
                   [-1, 1, FLAGS.max_doc_len]), tf.float32)

    # 1-based index of a clause predicted as emotion, shape (batch, 1).
    # Used to locate clauses relative to it for the cause layer.
    # NOTE(review): tf.argmax does not guarantee which index wins on ties,
    # which includes rows with several or no positive predictions.
    pred_emotion_assist_position = tf.cast(
        tf.reshape(tf.argmax(pred_emotion_assist_label, axis=2), [-1, 1]) +
        1, tf.int32)
    # Position of every clause relative to the predicted emotion clause,
    # shape (batch, max_doc_len). The +69 offset presumably makes the values
    # non-negative row indices into embedding_pos (TODO confirm).
    pred_clause_relative_position = tf.cast(
        tf.reshape(clause_position - pred_emotion_assist_position + 69,
                   [-1, FLAGS.max_doc_len]),
        tf.float32)
    # Presumably zeroes the indices of padded clauses (func.getmask is
    # defined elsewhere).
    pred_clause_relative_position *= func.getmask(doc_len,
                                                  FLAGS.max_doc_len,
                                                  [-1, FLAGS.max_doc_len])
    pred_clause_relative_position = tf.cast(pred_clause_relative_position,
                                            tf.int32)
    pred_clause_rep_embed = tf.nn.embedding_lookup(
        embedding_pos, pred_clause_relative_position)

    # Broadcast the label row to (batch, max_doc_len, max_doc_len) and zero
    # the diagonal, so that row i holds the labels of all other clauses.
    pred_emotion_assist_label = (pred_emotion_assist_label +
                                 pred_zeros) * matrix
    pred_emotion_assist_label_list.append(pred_emotion_assist_label)
    pred_emotion_assist_list.append(pred_emotion_assist)
    reg_emotion_assist_list.append(reg_emotion_assist)

    # ---- Assist layer 1: cause, conditioned on relative positions ----
    cause_senEncode_assist = tf.concat([senEncode, pred_clause_rep_embed],
                                       axis=2)
    n_feature = out_units + FLAGS.embedding_dim_pos
    # The name is kept as 'cause_layer' (no index): it is presumably used
    # for variable naming, and renaming it would break existing checkpoints.
    cause_senEncode = trans_func(cause_senEncode_assist, senEncode,
                                 n_feature, out_units, 'cause_layer')

    pred_cause_assist, reg_cause_assist = senEncode_cause_softmax(
        cause_senEncode, 'cause_softmax_assist_w1',
        'cause_softmax_assist_b1', out_units, doc_len)
    # (batch, 1, max_doc_len), then masked to (batch, max_doc_len, max_doc_len)
    # exactly as for the emotion labels.
    pred_cause_assist_label = tf.cast(
        tf.reshape(tf.argmax(pred_cause_assist, axis=2),
                   [-1, 1, FLAGS.max_doc_len]), tf.float32)
    pred_cause_assist_label = (pred_cause_assist_label +
                               pred_zeros) * matrix
    pred_cause_assist_label_list.append(pred_cause_assist_label)
    pred_cause_assist_list.append(pred_cause_assist)
    reg_cause_assist_list.append(reg_cause_assist)

    # ---- Assist layers 2 .. assist_n_layers - 1 ----
    for i in range(2, FLAGS.assist_n_layers):
        # The emotion layer sees its previous encoding plus the label
        # matrices summed over the earlier emotion and cause layers.
        emotion_senEncode_assist = tf.concat([
            emotion_senEncode, pred_emotion_assist_label,
            pred_cause_assist_label
        ],
                                             axis=2)
        n_feature = out_units + 2 * FLAGS.max_doc_len
        emotion_senEncode = trans_func(emotion_senEncode_assist,
                                       emotion_senEncode, n_feature, out_units,
                                       'emotion_layer' + str(i))

        pred_emotion_assist, reg_emotion_assist = senEncode_emotion_softmax(
            emotion_senEncode, 'emotion_softmax_assist_w' + str(i),
            'emotion_softmax_assist_b' + str(i), out_units, doc_len)
        pred_emotion_assist_label = tf.cast(
            tf.reshape(tf.argmax(pred_emotion_assist, axis=2),
                       [-1, 1, FLAGS.max_doc_len]), tf.float32)
        pred_emotion_assist_label = (pred_emotion_assist_label +
                                     pred_zeros) * matrix
        pred_emotion_assist_label_list.append(pred_emotion_assist_label)

        # Sum the masked label matrices of every emotion layer so far:
        # (batch, max_doc_len, max_doc_len).
        pred_emotion_assist_label = tf.reduce_sum(
            pred_emotion_assist_label_list, axis=0)

        pred_emotion_assist_list.append(pred_emotion_assist)
        reg_emotion_assist_list.append(reg_emotion_assist)

        # The cause layer uses the emotion label sum that already includes
        # this layer's emotion prediction (computed just above).
        cause_senEncode_assist = tf.concat([
            cause_senEncode, pred_cause_assist_label, pred_emotion_assist_label
        ],
                                           axis=2)
        n_feature = out_units + 2 * FLAGS.max_doc_len
        cause_senEncode = trans_func(cause_senEncode_assist, cause_senEncode,
                                     n_feature, out_units,
                                     'cause_layer' + str(i))

        pred_cause_assist, reg_cause_assist = senEncode_cause_softmax(
            cause_senEncode, 'cause_softmax_assist_w' + str(i),
            'cause_softmax_assist_b' + str(i), out_units, doc_len)
        pred_cause_assist_label = tf.cast(
            tf.reshape(tf.argmax(pred_cause_assist, axis=2),
                       [-1, 1, FLAGS.max_doc_len]), tf.float32)
        pred_cause_assist_label = (pred_cause_assist_label +
                                   pred_zeros) * matrix
        pred_cause_assist_label_list.append(pred_cause_assist_label)

        # Sum the masked label matrices of every cause layer so far.
        pred_cause_assist_label = tf.reduce_sum(pred_cause_assist_label_list,
                                                axis=0)

        pred_cause_assist_list.append(pred_cause_assist)
        reg_cause_assist_list.append(reg_cause_assist)

    # ---- Main network ----
    if FLAGS.main_n_layers > 1:
        senEncode_main = tf.concat([emotion_senEncode, cause_senEncode],
                                   axis=2)
        n_feature = 2 * out_units
        senEncode_main = trans_func(senEncode_main, senEncode_main, n_feature,
                                    out_units, 'main_layer1')
        # The second main layer also sees the accumulated assist labels.
        senEncode_main = tf.concat([
            senEncode_main, pred_emotion_assist_label, pred_cause_assist_label
        ],
                                   axis=2)
        n_feature = out_units + 2 * FLAGS.max_doc_len
        senEncode_main = trans_func(senEncode_main, senEncode_main, n_feature,
                                    out_units, 'main_layer2')
    else:
        senEncode_main = tf.concat([emotion_senEncode, cause_senEncode],
                                   axis=2)
        n_feature = 2 * out_units
        senEncode_main = trans_func(senEncode_main, senEncode_main, n_feature,
                                    out_units, 'main_layer1')
    pred, reg = senEncode_main_softmax(senEncode_main, 'softmax_w',
                                       'softmax_b', out_units, doc_len)

    return pred, reg, pred_emotion_assist_list, reg_emotion_assist_list, pred_cause_assist_list, reg_cause_assist_list