コード例 #1
0
ファイル: pick.py プロジェクト: SkyErnest/attribute_charge
def main(_):

    path = os.getcwd()
    father_path = os.path.dirname(path)
    checkpoint_dir = father_path + "/checkpoint/"
    lstm_log_dir = father_path + "/log/evaluation_charge_log/"
    attr_log_dir = father_path + "/log/evaluation_attr_log/"
    val_lstm_log_dir = father_path + "/log/validation_charge_log/"
    val_attr_log_dir = father_path + "/log/validation_attr_log/"

    create_dir([
        checkpoint_dir, lstm_log_dir, attr_log_dir, val_lstm_log_dir,
        val_attr_log_dir
    ])
    restore = True
    skiptrain = False
    valmatrix = False
    val_case = False
    mixandmatrix = False
    single_attr_log = False
    bs = 32
    perstep = 500
    eva_number = 0
    val_number = 0
    mixcouple = [69, 71]
    mixattr = [2, 3, 7]
    single_attr = [4, 9]

    word2id, word_embeddings, attr_table, x_train, y_train, y_attr_train, x_test, y_test, y_attr_test, x_val, y_val, y_attr_val, namehash, length_train, length_test, length_val = load_data_and_labels_fewshot(
    )
    id2word = {}
    for i in word2id:
        id2word[word2id[i]] = i
    batches = batch_iter(list(zip(x_train, y_train, y_attr_train)),
                         global_config.batch_size, global_config.num_epochs)
    lstm_config = model.lstm_Config()
    lstm_config.num_steps = len(x_train[0])
    lstm_config.hidden_size = len(word_embeddings[0])
    lstm_config.vocab_size = len(word_embeddings)
    lstm_config.num_classes = len(y_train[0])
    lstm_config.num_epochs = 20
    lstm_config.batch_size = bs
    lstm_config.num_epochs = 20

    lstm_eval_config = model.lstm_Config()
    lstm_eval_config.keep_prob = 1.0
    lstm_eval_config.num_steps = len(x_train[0])
    lstm_eval_config.hidden_size = len(word_embeddings[0])
    lstm_eval_config.vocab_size = len(word_embeddings)
    lstm_eval_config.num_classes = len(y_train[0])
    lstm_eval_config.batch_size = bs
    lstm_eval_config.num_epochs = 20

    zero_x = [0 for i in range(lstm_config.num_steps)]
    zero_y = [0 for i in range(lstm_config.num_classes)]

    lstm_count_tab = np.array([[0.0 for i in range(lstm_config.num_classes)]
                               for j in range(lstm_config.num_classes)])
    total_tab = np.array([0.0 for i in range(lstm_config.num_classes)])
    with tf.Graph().as_default():
        tf.set_random_seed(6324)
        tf_config = tf.ConfigProto()
        tf_config.gpu_options.allow_growth = True
        sess = tf.Session(config=tf_config)
        with sess.as_default():
            lstm_initializer = tf.contrib.layers.xavier_initializer()
            with tf.variable_scope("lstm_model",
                                   reuse=None,
                                   initializer=lstm_initializer):
                print('lstm step1')
                lstm_model = model.LSTM_MODEL(word_embeddings=word_embeddings,
                                              attr_table=attr_table,
                                              config=lstm_config)
                print('lstm step2')
                lstm_optimizer = tf.train.AdamOptimizer(lstm_config.lr)
                print('lstm step3')
                lstm_global_step = tf.Variable(0,
                                               name="lstm_global_step",
                                               trainable=False)
                lstm_train_op = lstm_optimizer.minimize(
                    lstm_model.total_loss, global_step=lstm_global_step)
                print('lstm step4')
            saver = tf.train.Saver()
            init_op = tf.initialize_all_variables()
            sess.run(init_op)
            best_macro_f1 = 0.0
            if restore:
                f_f1 = open(val_lstm_log_dir + 'best_macro_f1', 'r')
                f1s = f_f1.readlines()
                best_macro_f1 = float(
                    f1s[-1].strip().split(' ')[-1].strip('[').strip(']'))
                f_f1.close()
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    pass

            def lstm_train_step(x_batch, y_batch, y_attr_batch, length_batch):
                """
				A single training step
				"""
                feed_dict = {
                    lstm_model.input_x: x_batch,
                    lstm_model.input_length: length_batch,
                    lstm_model.input_y: y_batch,
                    lstm_model.unmapped_input_attr: y_attr_batch,
                    lstm_model.keep_prob: 0.5,
                }
                _, step, total_loss, lstm_loss, attr_loss = sess.run([
                    lstm_train_op, lstm_global_step, lstm_model.total_loss,
                    lstm_model.lstm_loss, lstm_model.total_attr_loss
                ], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                if step % 50 == 0:
                    #print sc
                    print((
                        "{}: step {}, total loss {:g}, lstm_loss {:g}, attr_loss {:g}"
                        .format(time_str, step, total_loss, lstm_loss,
                                attr_loss)))
                return step

            def lstm_dev_step(x_batch,
                              y_batch,
                              y_attr_batch,
                              length_batch,
                              writer=None):
                """
				Evaluates model on a dev set
				"""
                feed_dict = {
                    lstm_model.input_x: x_batch,
                    lstm_model.input_length: length_batch,
                    lstm_model.input_y: y_batch,
                    lstm_model.unmapped_input_attr: y_attr_batch,
                    lstm_model.keep_prob: 1.0,
                }
                runlist = [lstm_model.predictions, lstm_model.ans]

                lstm_p, lstm_l = sess.run(runlist, feed_dict=feed_dict)

                return lstm_p, lstm_l

            # batches = batch_iter(list(zip(x_train,y_train,y_attr_train,length_train)),lstm_config.batch_size, lstm_config.num_epochs)

            print('Evaluation')
            if mixandmatrix:
                f_mix = open(lstm_log_dir + str(eva_number) + 'mixed.html',
                             'w')
                f_mix.write(
                    '<head><meta http-equiv="Content-Type" content="text/html; charset=utf-8"/></head>\n'
                )
            if single_attr_log:
                f_single_attr = open(
                    lstm_log_dir + str(eva_number) + 'attr.html', 'w')
                f_single_attr.write(
                    '<head><meta http-equiv="Content-Type" content="text/html; charset=utf-8"/></head>\n'
                )

            all_count = 0.0
            total_losses, lstm_losses, attr_losses = 0.0, 0.0, 0.0
            lstm_prc = PrecRecallCounter(lstm_config.num_classes, lstm_log_dir,
                                         'lstm', eva_number)
            attr_prc = PrecRecallCounter(
                [2 for temp in range(global_config.num_of_attr)], attr_log_dir,
                'attr', eva_number)
            lstm_matrix = [[0 for j in range(lstm_config.num_classes)]
                           for i in range(lstm_config.num_classes)]
            num = int(len(y_test) / float(lstm_eval_config.batch_size))
            print(num)
            picked_x_index, picked_out, picked_label = [], [], []
            for i in range(num):
                if i % 100 == 0:
                    print(i)
                begin = i * lstm_eval_config.batch_size
                end = (i + 1) * lstm_eval_config.batch_size
                y_batch_t = y_test[begin:end]
                x_batch_t = x_test[begin:end]
                y_attr_batch_t = y_attr_test[begin:end]
                length_batch = length_test[begin:end]

                lstm_p, lstm_l = lstm_dev_step(x_batch_t, y_batch_t,
                                               y_attr_batch_t, length_batch)
                diff = np.not_equal(lstm_p, lstm_l)
                picked_out += list(lstm_p[diff])
                picked_x_index += list(
                    np.where(diff)[0] + i * lstm_config.batch_size)
                picked_label += list(lstm_l[diff])
                print(picked_x_index, picked_out, picked_label)

            print(picked_x_index, picked_out, picked_label)
            with open('wrong_predictions.txt', 'w', encoding='utf-8') as f:
                for i in list(zip(picked_x_index, picked_out, picked_label)):
                    f.write(str(i) + '\n')
コード例 #2
0
ファイル: train.py プロジェクト: rtygbwwwerr/attribute_charge
def main(_):

    path = os.getcwd()
    father_path = os.path.dirname(path)
    checkpoint_dir = father_path + "/checkpoint/"
    lstm_log_dir = father_path + "/log/evaluation_charge_log/"
    attr_log_dir = father_path + "/log/evaluation_attr_log/"
    val_lstm_log_dir = father_path + "/log/validation_charge_log/"
    val_attr_log_dir = father_path + "/log/validation_attr_log/"

    create_dir([
        checkpoint_dir, lstm_log_dir, attr_log_dir, val_lstm_log_dir,
        val_attr_log_dir
    ])
    restore = False
    skiptrain = False
    valmatrix = False
    val_case = False
    mixandmatrix = False
    single_attr_log = False
    bs = 32
    perstep = 500
    eva_number = 0
    val_number = 0
    mixcouple = [69, 71]
    mixattr = [2, 3, 7]
    single_attr = [4, 9]

    word2id, word_embeddings, attr_table, x_train, y_train, y_attr_train, x_test, y_test, y_attr_test, x_val, y_val, y_attr_val, namehash, length_train, length_test, length_val = load_data_and_labels_fewshot(
    )
    id2word = {}
    for i in word2id:
        id2word[word2id[i]] = i
    batches = batch_iter(list(zip(x_train, y_train, y_attr_train)),
                         global_config.batch_size, global_config.num_epochs)
    lstm_config = model.lstm_Config()
    lstm_config.num_steps = len(x_train[0])
    lstm_config.hidden_size = len(word_embeddings[0])
    lstm_config.vocab_size = len(word_embeddings)
    lstm_config.num_classes = len(y_train[0])
    lstm_config.num_epochs = 20
    lstm_config.batch_size = bs
    lstm_config.num_epochs = 20

    lstm_eval_config = model.lstm_Config()
    lstm_eval_config.keep_prob = 1.0
    lstm_eval_config.num_steps = len(x_train[0])
    lstm_eval_config.hidden_size = len(word_embeddings[0])
    lstm_eval_config.vocab_size = len(word_embeddings)
    lstm_eval_config.num_classes = len(y_train[0])
    lstm_eval_config.batch_size = bs
    lstm_eval_config.num_epochs = 20

    zero_x = [0 for i in range(lstm_config.num_steps)]
    zero_y = [0 for i in range(lstm_config.num_classes)]

    lstm_count_tab = np.array([[0.0 for i in range(lstm_config.num_classes)]
                               for j in range(lstm_config.num_classes)])
    total_tab = np.array([0.0 for i in range(lstm_config.num_classes)])
    with tf.Graph().as_default():
        tf.set_random_seed(6324)
        tf_config = tf.ConfigProto()
        tf_config.gpu_options.allow_growth = True
        sess = tf.Session(config=tf_config)
        with sess.as_default():
            lstm_initializer = tf.contrib.layers.xavier_initializer()
            with tf.variable_scope("lstm_model",
                                   reuse=None,
                                   initializer=lstm_initializer):
                print('lstm step1')
                lstm_model = model.LSTM_MODEL(word_embeddings=word_embeddings,
                                              attr_table=attr_table,
                                              config=lstm_config)
                print('lstm step2')
                lstm_optimizer = tf.train.AdamOptimizer(lstm_config.lr)
                print('lstm step3')
                lstm_global_step = tf.Variable(0,
                                               name="lstm_global_step",
                                               trainable=False)
                lstm_train_op = lstm_optimizer.minimize(
                    lstm_model.total_loss, global_step=lstm_global_step)
                print('lstm step4')
            saver = tf.train.Saver()
            init_op = tf.initialize_all_variables()
            sess.run(init_op)
            best_macro_f1 = 0.0
            if restore:
                f_f1 = open(val_lstm_log_dir + 'best_macro_f1', 'r')
                f1s = f_f1.readlines()
                best_macro_f1 = float(
                    f1s[-1].strip().split(' ')[-1].strip('[').strip(']'))
                f_f1.close()
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    pass

            def lstm_train_step(x_batch, y_batch, y_attr_batch, length_batch):
                """
				A single training step
				"""
                feed_dict = {
                    lstm_model.input_x: x_batch,
                    lstm_model.input_length: length_batch,
                    lstm_model.input_y: y_batch,
                    lstm_model.unmapped_input_attr: y_attr_batch,
                    lstm_model.keep_prob: 0.5,
                }
                _, step, total_loss, lstm_loss, attr_loss = sess.run([
                    lstm_train_op, lstm_global_step, lstm_model.total_loss,
                    lstm_model.lstm_loss, lstm_model.total_attr_loss
                ], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                if step % 50 == 0:
                    #print sc
                    print(
                        "{}: step {}, total loss {:g}, lstm_loss {:g}, attr_loss {:g}"
                        .format(time_str, step, total_loss, lstm_loss,
                                attr_loss))
                return step

            def lstm_dev_step(x_batch,
                              y_batch,
                              y_attr_batch,
                              length_batch,
                              writer=None):
                """
				Evaluates model on a dev set
				"""
                feed_dict = {
                    lstm_model.input_x: x_batch,
                    lstm_model.input_length: length_batch,
                    lstm_model.input_y: y_batch,
                    lstm_model.unmapped_input_attr: y_attr_batch,
                    lstm_model.keep_prob: 1.0,
                }
                runlist = [
                    lstm_model.predictions, lstm_model.attr_preds,
                    lstm_model.total_loss, lstm_model.lstm_loss,
                    lstm_model.total_attr_loss, lstm_model.attn_weights
                ]

                lstm_p, attr_p, t_loss, l_loss, a_loss, attn_weights = sess.run(
                    runlist, feed_dict=feed_dict)

                return lstm_p, attr_p, t_loss, l_loss, a_loss, attn_weights

            batches = batch_iter(
                list(zip(x_train, y_train, y_attr_train, length_train)),
                lstm_config.batch_size, lstm_config.num_epochs)

            for batch in batches:
                x_batch, y_batch, y_attr_batch, length_batch = zip(*batch)
                step = lstm_train_step(x_batch, y_batch, y_attr_batch,
                                       length_batch)
                if ((step % perstep) == 0) or (skiptrain):

                    print('Evaluation')
                    if mixandmatrix:
                        f_mix = open(
                            lstm_log_dir + str(eva_number) + 'mixed.html', 'w')
                        f_mix.write(
                            '<head><meta http-equiv="Content-Type" content="text/html; charset=utf-8"/></head>\n'
                        )
                    if single_attr_log:
                        f_single_attr = open(
                            lstm_log_dir + str(eva_number) + 'attr.html', 'w')
                        f_single_attr.write(
                            '<head><meta http-equiv="Content-Type" content="text/html; charset=utf-8"/></head>\n'
                        )

                    all_count = 0.0
                    total_losses, lstm_losses, attr_losses = 0.0, 0.0, 0.0
                    lstm_prc = PrecRecallCounter(lstm_config.num_classes,
                                                 lstm_log_dir, 'lstm',
                                                 eva_number)
                    attr_prc = PrecRecallCounter(
                        [2 for temp in range(global_config.num_of_attr)],
                        attr_log_dir, 'attr', eva_number)
                    lstm_matrix = [[0 for j in range(lstm_config.num_classes)]
                                   for i in range(lstm_config.num_classes)]
                    num = int(len(y_test) / float(lstm_eval_config.batch_size))
                    print(num)
                    for i in range(num):
                        if i % 100 == 0:
                            print(i)
                        begin = i * lstm_eval_config.batch_size
                        end = (i + 1) * lstm_eval_config.batch_size
                        y_batch_t = y_test[begin:end]
                        x_batch_t = x_test[begin:end]
                        y_attr_batch_t = y_attr_test[begin:end]
                        length_batch = length_test[begin:end]

                        lstm_p, attr_p, t_loss, l_loss, a_loss, attn_weights = lstm_dev_step(
                            x_batch_t, y_batch_t, y_attr_batch_t, length_batch)
                        total_losses += t_loss
                        lstm_losses += l_loss
                        attr_losses += a_loss
                        for j in range(lstm_eval_config.batch_size):
                            indexes = np.flatnonzero(y_batch_t[j])
                            lstm_prc.multicount(lstm_p[j], indexes)
                            for index in indexes:
                                lstm_matrix[index][lstm_p[j]] += 1
                            for k in range(global_config.num_of_attr):
                                attr_prc.count(attr_p[j][k],
                                               y_attr_batch_t[j][k], k)
                            if mixandmatrix:
                                mixed = ismixed(mixcouple, lstm_p[j], indexes)
                                if mixed:
                                    wordcolor = '<font style="background: rgba(255, 255, 0, %f)">%s</font>\n'
                                    f_mix.write('<p>' + str(lstm_p[j]) + ' ' +
                                                str(indexes) + '</p>\n')
                                    towrite = ''
                                    for k in range(global_config.num_of_attr):
                                        towrite = towrite + str(
                                            attr_p[j][k]) + ' '
                                    f_mix.write('<p>' + towrite + '</p>\n')
                                    towrite = ''
                                    for k in range(global_config.num_of_attr):
                                        towrite = towrite + str(
                                            y_attr_batch_t[j][k]) + ' '
                                    f_mix.write('<p>' + towrite + '</p>\n')
                                    for c in mixattr:
                                        f_mix.write(wordcolor % (0, str(c)))
                                        for w in range(len(x_batch_t[j])):
                                            if w == length_batch[j]:
                                                break
                                            f_mix.write(
                                                wordcolor %
                                                (attn_weights[j][c][w] /
                                                 np.max(attn_weights[j][c]),
                                                 id2word[x_batch_t[j][w]]))
                                        f_mix.write('<p>---</p>\n')
                            if single_attr_log:
                                for attr_index in single_attr:
                                    if (attr_p[j][attr_index] !=
                                            y_attr_batch_t[j][attr_index]) & (
                                                y_attr_batch_t[j][attr_index]
                                                != 2):
                                        wordcolor = '<font style="background: rgba(255, 255, 0, %f)">%s</font>\n'
                                        f_single_attr.write(
                                            '<p>' + str(indexes) +
                                            str(attr_index) + ' ' +
                                            str(attr_p[j][attr_index]) + ' ' +
                                            str(y_attr_batch_t[j]
                                                [attr_index]) + '</p>\n')
                                        for w in range(len(x_batch_t[j])):
                                            if w == length_batch[j]:
                                                break
                                            f_single_attr.write(
                                                wordcolor %
                                                (attn_weights[j][attr_index][w]
                                                 / np.max(attn_weights[j]
                                                          [attr_index]),
                                                 id2word[x_batch_t[j][w]]))
                                        f_single_attr.write('<p>---</p>\n')

                    begin = num * lstm_eval_config.batch_size
                    y_batch_t = y_test[begin:]
                    x_batch_t = x_test[begin:]
                    y_attr_batch_t = y_attr_test[begin:]
                    length_batch = length_test[begin:]
                    cl = len(y_batch_t)
                    for itemp in range(lstm_eval_config.batch_size - cl):
                        y_batch_t = np.append(y_batch_t, [y_batch_t[0]],
                                              axis=0)
                        x_batch_t = np.append(x_batch_t, [x_batch_t[0]],
                                              axis=0)
                        y_attr_batch_t = np.append(y_attr_batch_t,
                                                   [y_attr_batch_t[0]],
                                                   axis=0)
                        length_batch = np.append(length_batch,
                                                 [length_batch[0]],
                                                 axis=0)
                    lstm_p, attr_p, t_loss, l_loss, a_loss, attn_weights = lstm_dev_step(
                        x_batch_t, y_batch_t, y_attr_batch_t, length_batch)
                    total_losses += t_loss
                    lstm_losses += l_loss
                    attr_losses += a_loss
                    for jtemp in range(cl):
                        indexes = np.flatnonzero(y_batch_t[jtemp])
                        lstm_prc.multicount(lstm_p[jtemp], indexes)
                        for index in indexes:
                            lstm_matrix[index][lstm_p[jtemp]] += 1
                        for k in range(global_config.num_of_attr):
                            attr_prc.count(attr_p[jtemp][k],
                                           y_attr_batch_t[jtemp][k], k)

                    lstm_prc.compute()
                    attr_prc.compute()

                    lstm_prc.output()
                    attr_prc.output()

                    if (lstm_prc.macro_f1[0] > best_macro_f1) or skiptrain:
                        best_macro_f1 = lstm_prc.macro_f1[0]
                        f_f1 = open(val_lstm_log_dir + 'best_macro_f1', 'a+')
                        f_f1.write('eva:' + str(eva_number) + ' ' +
                                   str(best_macro_f1) + '\n')
                        f_f1.close()
                        print('Validation')
                        if not skiptrain:
                            saver.save(sess,
                                       checkpoint_dir + 'model.ckpt',
                                       global_step=step)
                        all_count = 0.0
                        total_losses, lstm_losses, attr_losses = 0.0, 0.0, 0.0
                        val_lstm_prc = PrecRecallCounter(
                            lstm_config.num_classes, val_lstm_log_dir, 'lstm',
                            val_number)
                        val_attr_prc = PrecRecallCounter(
                            [2 for temp in range(global_config.num_of_attr)],
                            val_attr_log_dir, 'attr', val_number)
                        val_lstm_matrix = [[
                            0 for j in range(lstm_config.num_classes)
                        ] for i in range(lstm_config.num_classes)]
                        num = int(
                            len(y_val) / float(lstm_eval_config.batch_size))
                        if val_case:
                            f_case = open(
                                val_lstm_log_dir + 'case' + str(val_number),
                                'w')
                        print(num)
                        for i in range(num):
                            if i % 100 == 0:
                                print(i)
                            begin = i * lstm_eval_config.batch_size
                            end = (i + 1) * lstm_eval_config.batch_size
                            y_batch_t = y_val[begin:end]
                            x_batch_t = x_val[begin:end]
                            y_attr_batch_t = y_attr_val[begin:end]
                            length_batch = length_val[begin:end]

                            lstm_p, attr_p, t_loss, l_loss, a_loss, attn_weights = lstm_dev_step(
                                x_batch_t, y_batch_t, y_attr_batch_t,
                                length_batch)
                            total_losses += t_loss
                            lstm_losses += l_loss
                            attr_losses += a_loss
                            for j in range(lstm_eval_config.batch_size):
                                indexes = np.flatnonzero(y_batch_t[j])
                                val_lstm_prc.multicount(lstm_p[j], indexes)
                                for index in indexes:
                                    val_lstm_matrix[index][lstm_p[j]] += 1
                                for k in range(global_config.num_of_attr):
                                    val_attr_prc.count(attr_p[j][k],
                                                       y_attr_batch_t[j][k], k)
                                if val_case:
                                    towrite = str(lstm_p[j]) + '\t' + str(
                                        indexes[0]) + '\t' + str(
                                            attr_p[j]) + '\t' + str(
                                                y_attr_batch_t[j]) + '\t'
                                    for w in range(len(x_batch_t[j])):
                                        if w == length_batch[j]:
                                            break
                                        towrite = towrite + id2word[
                                            x_batch_t[j][w]] + ' '
                                    for temp_attr in range(
                                            global_config.num_of_attr):
                                        towrite = towrite + '\t'
                                        for w in range(len(x_batch_t[j])):
                                            if w == length_batch[j]:
                                                break
                                            towrite = towrite + str(
                                                attn_weights[j][temp_attr][w] /
                                                np.max(attn_weights[j]
                                                       [temp_attr])) + ' '
                                    towrite = towrite + '\n'
                                    f_case.write(towrite)
                        begin = num * lstm_eval_config.batch_size
                        y_batch_t = y_val[begin:]
                        x_batch_t = x_val[begin:]
                        y_attr_batch_t = y_attr_val[begin:]
                        length_batch = length_val[begin:]
                        cl = len(y_batch_t)
                        for itemp in range(lstm_eval_config.batch_size - cl):
                            y_batch_t = np.append(y_batch_t, [y_batch_t[0]],
                                                  axis=0)
                            x_batch_t = np.append(x_batch_t, [x_batch_t[0]],
                                                  axis=0)
                            y_attr_batch_t = np.append(y_attr_batch_t,
                                                       [y_attr_batch_t[0]],
                                                       axis=0)
                            length_batch = np.append(length_batch,
                                                     [length_batch[0]],
                                                     axis=0)
                        lstm_p, attr_p, t_loss, l_loss, a_loss, attn_weights = lstm_dev_step(
                            x_batch_t, y_batch_t, y_attr_batch_t, length_batch)
                        total_losses += t_loss
                        lstm_losses += l_loss
                        attr_losses += a_loss
                        for jtemp in range(cl):
                            indexes = np.flatnonzero(y_batch_t[jtemp])
                            val_lstm_prc.multicount(lstm_p[jtemp], indexes)
                            for index in indexes:
                                val_lstm_matrix[index][lstm_p[jtemp]] += 1
                            for k in range(global_config.num_of_attr):
                                val_attr_prc.count(attr_p[jtemp][k],
                                                   y_attr_batch_t[jtemp][k], k)
                            if val_case:
                                towrite = str(lstm_p[jtemp]) + '\t' + str(
                                    indexes[0]) + '\t' + str(
                                        attr_p[jtemp]) + '\t' + str(
                                            y_attr_batch_t[jtemp]) + '\t'
                                for w in range(len(x_batch_t[jtemp])):
                                    if w == length_batch[jtemp]:
                                        break
                                    towrite = towrite + id2word[
                                        x_batch_t[jtemp][w]] + ' '
                                for temp_attr in range(
                                        global_config.num_of_attr):
                                    towrite = towrite + '\t'
                                    for w in range(len(x_batch_t[jtemp])):
                                        if w == length_batch[jtemp]:
                                            break
                                        towrite = towrite + str(
                                            attn_weights[jtemp][temp_attr][w] /
                                            np.max(attn_weights[jtemp]
                                                   [temp_attr])) + ' '
                                towrite = towrite + '\n'
                                f_case.write(towrite)

                        val_lstm_prc.compute()
                        val_attr_prc.compute()
                        val_lstm_prc.output()
                        val_attr_prc.output()
                        if valmatrix:
                            fm = open(
                                val_lstm_log_dir + str(val_number) + 'matrix',
                                'w')
                            for i in range(lstm_config.num_classes):
                                towrite = ""
                                for j in range(lstm_config.num_classes):
                                    towrite = towrite + str(
                                        val_lstm_matrix[i][j]) + ' '
                                towrite = towrite + '\n'
                                fm.write(towrite)
                            fm.close()
                        val_number += 1

                    if mixandmatrix:
                        fm = open(lstm_log_dir + str(eva_number) + 'matrix',
                                  'w')
                        for i in range(lstm_config.num_classes):
                            towrite = ""
                            for j in range(lstm_config.num_classes):
                                towrite = towrite + str(
                                    lstm_matrix[i][j]) + ' '
                            towrite = towrite + '\n'
                            fm.write(towrite)
                        fm.close()

                    num = float(num)
                    tn = datetime.datetime.now()
                    print(tn.isoformat())
                    print('loss total:{:g}, lstm:{:g}, attr:{:g}'.format(
                        total_losses / num, lstm_losses / num,
                        attr_losses / num))
                    if skiptrain:
                        break
                    eva_number += 1
コード例 #3
0
def main(_):
    path = os.getcwd()
    father_path = os.path.dirname(path)
    checkpoint_dir = father_path + "/checkpoint/"
    lstm_log_dir = father_path + "/log/evaluation_charge_log/"
    attr_log_dir = father_path + "/log/evaluation_attr_log/"
    val_lstm_log_dir = father_path + "/log/validation_charge_log/"
    val_attr_log_dir = father_path + "/log/validation_attr_log/"
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if not os.path.exists(lstm_log_dir):
        os.makedirs(lstm_log_dir)
    if not os.path.exists(attr_log_dir):
        os.makedirs(attr_log_dir)
    if not os.path.exists(val_lstm_log_dir):
        os.makedirs(val_lstm_log_dir)
    if not os.path.exists(val_attr_log_dir):
        os.makedirs(val_attr_log_dir)
    restore = False
    skiptrain = False
    bs = 32  #batch size
    perstep = 500
    eva_number = 0
    val_number = 0
    #init
    print "loading word embedding and data..."
    word2id, word_embeddings, attr_table, x_train, y_train, y_attr_train, x_test, y_test, y_attr_test, x_val, y_val, y_attr_val, namehash, length_train, length_test, length_val = load_data_and_labels_fewshot(
    )
    id2word = {}
    for i in word2id:
        id2word[word2id[i]] = i
    batches = batch_iter(list(zip(x_train, y_train, y_attr_train)),
                         global_config.batch_size, global_config.num_epochs)
    lstm_config = model.lstm_Config()
    lstm_config.num_steps = len(x_train[0])
    lstm_config.hidden_size = len(word_embeddings[0])
    lstm_config.vocab_size = len(word_embeddings)
    lstm_config.num_classes = len(y_train[0])
    lstm_config.num_epochs = 20
    lstm_config.batch_size = bs

    with tf.Graph().as_default():
        tf.set_random_seed(6324)
        tf_config = tf.ConfigProto()
        tf_config.gpu_options.allow_growth = True
        sess = tf.Session(config=tf_config)
        print "initializing model"
        with sess.as_default():
            lstm_initializer = tf.contrib.layers.xavier_initializer()
            with tf.variable_scope("lstm_model",
                                   reuse=None,
                                   initializer=lstm_initializer):
                lstm_model = model.LSTM_MODEL(word_embeddings=word_embeddings,
                                              attr_table=attr_table,
                                              config=lstm_config)
                lstm_optimizer = tf.train.AdamOptimizer(lstm_config.lr)
                lstm_global_step = tf.Variable(0,
                                               name="lstm_global_step",
                                               trainable=False)
                lstm_train_op = lstm_optimizer.minimize(
                    lstm_model.total_loss, global_step=lstm_global_step)
            saver = tf.train.Saver()
            init_op = tf.initialize_all_variables()
            sess.run(init_op)
            best_macro_f1 = 0.0
            if restore:
                f_f1 = open(val_lstm_log_dir + 'best_macro_f1', 'r')
                f1s = f_f1.readlines()
                best_macro_f1 = float(
                    f1s[-1].strip().split(' ')[-1].strip('[').strip(']'))
                f_f1.close()
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    pass
            print "initialized"

            batches = batch_iter(
                list(zip(x_train, y_train, y_attr_train, length_train)),
                lstm_config.batch_size, lstm_config.num_epochs)

            for batch in batches:
                x_batch, y_batch, y_attr_batch, length_batch = zip(*batch)
                step = lstm_train_step(lstm_train_op, lstm_global_step,
                                       lstm_model, sess, x_batch, y_batch,
                                       y_attr_batch, length_batch)
                if ((step % perstep) == 0) or (skiptrain):
                    new_marco_f1 = evaluation(eva_number, lstm_model, sess,
                                              lstm_config, lstm_log_dir,
                                              attr_log_dir, y_test, x_test,
                                              y_attr_test, length_test)
                    eva_number += 1
                    #when model get the best performance on test set, validate it on the validation set
                    if (new_marco_f1 > best_macro_f1) or skiptrain:
                        if not skiptrain:
                            saver.save(sess,
                                       checkpoint_dir + 'model.ckpt',
                                       global_step=step)
                        best_macro_f1 = new_marco_f1
                        validation(val_number, lstm_model, sess, lstm_config,
                                   val_lstm_log_dir, val_attr_log_dir, y_val,
                                   x_val, y_attr_val, length_val)
                        val_number += 1