Example #1
for i in range(1000 * FLAGS.NUM_OF_EPOCH):
    # i = 0
    # Decay the learning rate at each epoch boundary.
    if i % Num_of_Iterlation == 0:
        epoch += 1
        FLAGS.lr_value *= FLAGS.lr_decay

    ################################################################
    # Training batch OPTIMIZE
    ################################################################
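    # Draw a batch of indices for this step; utils.sampler appears to balance
    # positive and negative examples (project-specific helper).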
    index = utils.sampler(LABEL_POS=TRAIN_LABEL_POS,
                          LABEL_NEG=TRAIN_LABEL_NEG,
                          BATCH_SIZE=FLAGS.BATCH_SIZE)
    batch_input, batch_label = utils.generate_batch_jaso(INDEX=index,
                                                         MODEL=model,
                                                         DOC=TRAIN_DOC,
                                                         LABEL=TRAIN_LABEL,
                                                         MAXLEN=FLAGS.MAXLEN,
                                                         SESS=sess,
                                                         ATTENTION=False)
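    # Per-example input lengths, presumably consumed by the model for
    # variable-length sequences (e.g. dynamic RNN unrolling).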
    seq_length, _ = utils.length(batch_input)

    _ = sess.run(
        [model.optm],
        feed_dict={
            model.X: batch_input,
            model.Y: batch_label,
            model.LEARNING_RATE: FLAGS.lr_value,
            model.SEQ: seq_length,
            model.Dropout_Rate1: FLAGS.Dropout_Rate1,
            model.Dropout_Rate2: FLAGS.Dropout_Rate2,
            model.TRAIN_PH: True
        })

Example #2
def main():
    args = parse_args()
    if args is None:
        exit()

    # test file
    filename = args.inputdir + args.input

    # init model config
    #TRAIN_DOC, TRAIN_LABEL, TEST_DOC, TEST_LABEL, LABEL_IDX = data_load.digi_data_load()
    TRAIN_DOC, TRAIN_LABEL, TEST_DOC, TEST_LABEL, LABEL_IDX = data_load.testcase_shuffle_data_load(
    )
    #data_load.testcase_add_data_load()
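    # Labels are one-hot encoded, so the label width gives the number of classes.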
    class_num = TRAIN_LABEL.shape[1]
    FLAGS.NUM_OF_CLASS = class_num
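    # Build the jamo (jaso) lookup table, which presumably maps characters to
    # one-hot indices (project-specific utility).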
    JM = utils.lookup_JM(FLAGS.INPUT_WIDTH, FLAGS.INPUT_DEPTH)

    # Start Session
    sess = tf.Session()
    print("Session Ready!")
    model = MODEL(sess=sess, JM=JM, FLAGS=FLAGS)

    # Initialization
    sess.run(tf.global_variables_initializer())
    model.JM.init_table(sess)

    # Restore parameter
    saver = tf.train.Saver()
    saver.restore(sess, "./Saver/{}/{}.ckpt".format(FLAGS.WRITER,
                                                    FLAGS.WRITER))

    if args.printtest == 'True':
        index = np.array(range(0, len(TEST_DOC)))
        batch_input, batch_label = utils.generate_batch_jaso(
            INDEX=index,
            MODEL=model,
            DOC=TEST_DOC,
            LABEL=TEST_LABEL,
            MAXLEN=FLAGS.INPUT_WIDTH,
            SESS=sess)

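        # Single evaluation pass over the entire test set.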
        proba, ts_loss, ts_acc, ts_merged = sess.run(
            [model.y_proba, model.cross_entropy, model.accuracy, model.merge],
            feed_dict={
                model.X: batch_input,
                model.Y: batch_label,
                model.LEARNING_RATE: FLAGS.lr_value,
                model.TRAIN_PH: False
            })

        pred_idx = np.apply_along_axis(np.argmax, 1, proba)
        real_idx = np.apply_along_axis(np.argmax, 1, batch_label)

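        # Indices of correctly (pos) and incorrectly (neg) classified test examples.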
        pos_idx = np.where(pred_idx == real_idx)[0]
        neg_idx = np.where(pred_idx != real_idx)[0]

        print('[ TEST ]')
        desc = """
            size:{}, correct:{}, wrong:{}, acc:{}, f1_score:{}, ts_loss:{}, ts_acc:{}
        """.format(index.shape[0], pos_idx.shape[0], neg_idx.shape[0],
                   round(pos_idx.shape[0] / index.shape[0] * 100, 3),
                   round(f1_score(real_idx, pred_idx, average='weighted'), 4),
                   ts_loss, ts_acc)
        print(desc)
        for idx in pos_idx:
            print('Positive Case:\t', TEST_DOC[index[idx]], '\t->\t',
                  LABEL_IDX[np.argmax(proba[idx])],
                  '({0:.2f})\t'.format(round(max(proba[idx]), 3)),
                  LABEL_IDX[np.argmax(batch_label[idx])])

        for idx in neg_idx:
            print(
                'Negative Case:\t', TEST_DOC[index[idx]], '\t->\t',
                LABEL_IDX[np.argmax(proba[idx])],
                '({0:.2f})\t'.format(round(max(proba[idx]), 3)),
                LABEL_IDX[np.argmax(batch_label[idx])],
                '({0:.2f})\t'.format(proba[idx][np.argmax(batch_label[idx])]))

        print()

    if args.input == 'a.txt':
        testfile = open(filename)
        for line in testfile:
            line = line.rstrip('\n\r').lower()

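            # Decompose the line into Korean jamo (jaso) and one-hot encode it;
            # MAXLEN presumably pads/trims the sequence to INPUT_WIDTH.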
            jaso_splitted = jmu.jaso_split([line], MAXLEN=FLAGS.INPUT_WIDTH)
            batch_input = sess.run(model.jaso_Onehot,
                                   {model.X_Onehot: jaso_splitted})
            y_proba = sess.run(model.y_proba,
                               feed_dict={
                                   model.X: batch_input,
                                   model.TRAIN_PH: False
                               })
            #print(batch_input.shape, y_proba.shape)
            label = LABEL_IDX[np.argmax(y_proba[0])]
            if round(max(y_proba[0]), 3) > 0.8:
                print(line, '\t->\t', label, round(max(y_proba[0]), 3))

    if args.input == 'tmp.txt':
        tmp_df = pd.read_csv(filename, sep='\t', header=None)
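        # json2intent presumably parses the JSON annotation in column 1 and
        # returns its intent label (project-specific helper).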
        tmp_df['class'] = tmp_df[1].apply(lambda x: json2intent(x))
        sentences = tmp_df[0].tolist()

        labels = []
        probas = []
        for sent in sentences:
            jaso_splitted = jmu.jaso_split([sent], MAXLEN=FLAGS.INPUT_WIDTH)
            batch_input = sess.run(model.jaso_Onehot,
                                   {model.X_Onehot: jaso_splitted})
            y_proba = sess.run(model.y_proba,
                               feed_dict={
                                   model.X: batch_input,
                                   model.TRAIN_PH: False
                               })
            # print(batch_input.shape, y_proba.shape)
            label = LABEL_IDX[np.argmax(y_proba[0])]
            labels.append(label)
            probas.append(round(max(y_proba[0]), 3))

        tmp_df['pred'] = labels
        tmp_df['proba'] = probas

        for index, row in tmp_df.iterrows():
            # if row['class'] is None:
            #     continue

            if row['class'] is not None:
                continue

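            # Only rows without a parsed intent reach this point; report
            # predictions whose confidence exceeds 0.8.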
            if round(row['proba'], 3) > 0.8:
                print(row[0], '\t->\t', row['pred'], round(row['proba'], 3),
                      '\t', row['class'])

    if args.input == 'tmp_not_under.txt':
        tmp_df = pd.read_csv(filename, sep='\t', header=None)
        tmp_df['class'] = tmp_df[1].apply(lambda x: json2intent(x))
        sentences = tmp_df[0].tolist()

        labels = []
        probas = []
        for sent in sentences:
            jaso_splitted = jmu.jaso_split([sent], MAXLEN=FLAGS.INPUT_WIDTH)
            batch_input = sess.run(model.jaso_Onehot,
                                   {model.X_Onehot: jaso_splitted})
            y_proba = sess.run(model.y_proba,
                               feed_dict={
                                   model.X: batch_input,
                                   model.TRAIN_PH: False
                               })
            # print(batch_input.shape, y_proba.shape)
            label = LABEL_IDX[np.argmax(y_proba[0])]
            labels.append(label)
            probas.append(round(max(y_proba[0]), 3))

        tmp_df['pred'] = labels
        tmp_df['proba'] = probas

        for index, row in tmp_df.iterrows():

            if round(row['proba'], 3) > 0.1:
                print(row[1])
                print(row[0], '\t->\t', row['pred'], round(row['proba'], 3),
                      '\t')
                print()
Example #3
################################################################################
# Number of iterations per epoch
Num_of_Iterlation = np.shape(TRAIN_DOC)[0] // FLAGS.BATCH_SIZE
epoch = 0
for i in range(Num_of_Iterlation * FLAGS.NUM_OF_EPOCH):
    # i=0
    if i % Num_of_Iterlation == 0:
        epoch += 1
        FLAGS.lr_value *= FLAGS.lr_decay

    ################################################################
    # Training batch OPTIMIZE
    ################################################################
    #index = utils.sampler(LABEL_POS=TRAIN_LABEL_POS, LABEL_NEG=TRAIN_LABEL_NEG, BATCH_SIZE=FLAGS.BATCH_SIZE)
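    # Unlike Example #1, the class-balanced sampler is replaced here with
    # uniform random sampling over the training set.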
    index = np.random.choice(range(0, len(TRAIN_DOC)), FLAGS.BATCH_SIZE)
    batch_input, batch_label = utils.generate_batch_jaso(INDEX=index, MODEL=model, DOC=TRAIN_DOC,
                                                         LABEL=TRAIN_LABEL, MAXLEN=FLAGS.INPUT_WIDTH, SESS=sess)

    _ = sess.run([model.optm],
                 feed_dict={model.X: batch_input,
                            model.Y: batch_label,
                            model.LEARNING_RATE: FLAGS.lr_value,
                            model.TRAIN_PH: True})


    ################################################################
    # Calculate Train & Test Loss, Accuracy & Summary / Print
    ################################################################
    if i % 10 == 0:
        ################################################################
        # Train batch LOSS CHECK
        ################################################################
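        # NOTE: the excerpt is truncated here. A minimal sketch (not the original
        # author's code) of a train-batch loss check, reusing only tensors already
        # referenced in these examples (model.cross_entropy, model.accuracy):
        tr_loss, tr_acc = sess.run(
            [model.cross_entropy, model.accuracy],
            feed_dict={model.X: batch_input,
                       model.Y: batch_label,
                       model.TRAIN_PH: False})
        print('epoch:{} iter:{} tr_loss:{:.4f} tr_acc:{:.4f}'.format(
            epoch, i, tr_loss, tr_acc))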