Code example #1
File: predict.py   Project: wfgreg/tbcnn
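All five snippets come from the same tbcnn sources and rely on module-level imports and constants that this listing omits. A minimal sketch of what they assume follows; the constant values are placeholders, not the project's actual settings.

import os
import pickle
import sys

import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

import network    # project module that builds the TBCNN graph
import sampling   # project module that generates and batches tree samples

# Constants used by code example #4; the values here are illustrative.
EPOCHS = 50
LEARN_RATE = 0.001
BATCH_SIZE = 1
CHECKPOINT_EVERY = 10000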
def predict_model(args, logdir, infile, embedfile):
    """Test a classifier to label ASTs"""

    with open(infile, 'rb') as fh:
        _, trees, cv, labels = pickle.load(fh)

    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
        num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, len(labels))
    out_node = network.out_layer(hidden_node)

    ### init the graph
    sess = tf.Session()  # pass config=tf.ConfigProto(device_count={'GPU': 0}) to force CPU
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError('Checkpoint not found.')

    correct_labels = []
    # make predictions from the input
    predictions = []
    step = 0
    for batch in sampling.batch_samples(
            sampling.gen_samples(trees, labels, embeddings, embed_lookup), 1):
        nodes, children, batch_labels = batch
        output = sess.run([out_node],
                          feed_dict={
                              nodes_node: nodes,
                              children_node: children,
                          })
        correct_labels.append(np.argmax(batch_labels))
        predictions.append(np.argmax(output))
        # print(step, '/', len(trees))

        # the original label (labels[np.argmax(batch_labels)]) is available,
        # but its display is disabled here
        orig_label = ""

        pred_num = np.argmax(output)
        pred_label = labels[pred_num]
        pred_score = output[0][0][pred_num]
        pred_item = str(step + 1) + '/' + str(len(trees))
        if 'meta' in trees[step] and 'name' in trees[step]['meta']:
            pred_item += "   " + trees[step]['meta']['name']
        print(pred_label + '   \t' + str(pred_score) + ' \t' + orig_label +
              pred_item)

        step += 1
Code example #2
def test_model(logdir, infile, embedfile):
    """Test a classifier to label ASTs"""

    with open(infile, 'rb') as fh:
        _, trees, labels = pickle.load(fh)

    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
        num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, len(labels))
    out_node = network.out_layer(hidden_node)

    ### init the graph
    sess = tf.Session()  # pass config=tf.ConfigProto(device_count={'GPU': 0}) to force CPU
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError('Checkpoint not found.')

    correct_labels = []
    # make predictions from the input
    predictions = []
    step = 0
    for batch in sampling.batch_samples(
            sampling.gen_samples(trees, labels, embeddings, embed_lookup), 1):
        nodes, children, batch_labels = batch
        output = sess.run([out_node],
                          feed_dict={
                              nodes_node: nodes,
                              children_node: children,
                          })

        correct_labels.append(np.argmax(batch_labels))
        predictions.append(np.argmax(output))
        step += 1
        print(step, '/', len(trees))

    target_names = list(labels)
    print('Accuracy:', accuracy_score(correct_labels, predictions))
    print(
        classification_report(correct_labels,
                              predictions,
                              target_names=target_names))
    print(confusion_matrix(correct_labels, predictions))
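
For reference, a hypothetical call; the paths are illustrative stand-ins for the pickle files produced by the project's parsing and embedding steps:

test_model('./logs', './data/trees.pkl', './data/embeddings.pkl')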
Code example #3
def classify_item(logdir, infile, embedfile, label_file):
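    """Classify items with a trained network, printing the predicted label for each."""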

    with open(label_file, 'rb') as fh:
        labels = pickle.load(fh)

    with open(infile, 'rb') as fh:
        trees, _ = pickle.load(fh)  # keep the label list loaded from label_file

    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
        num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, len(labels))
    out_node = network.out_layer(hidden_node)

    ### init the graph
    sess = tf.Session()  # pass config=tf.ConfigProto(device_count={'GPU': 0}) to force CPU
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError('Checkpoint not found.')

    for batch in sampling.batch_samples(
            sampling.gen_samples(trees, labels, embeddings, embed_lookup), 1):
        nodes, children, batch_labels = batch
        output = sess.run([out_node],
                          feed_dict={
                              nodes_node: nodes,
                              children_node: children,
                          })
        print(labels[np.argmax(output)])
Code example #4
def train_model(logdir, infile, embedfile, epochs=EPOCHS):
    """Train a classifier to label ASTs"""

    with open(infile, 'rb') as fh:
        trees, _, labels = pickle.load(fh)

    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
        num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, len(labels))

    out_node = network.out_layer(hidden_node)
    labels_node, loss_node = network.loss_layer(hidden_node, len(labels))

    optimizer = tf.train.AdamOptimizer(LEARN_RATE)
    train_step = optimizer.minimize(loss_node)

    tf.summary.scalar('loss', loss_node)

    ### init the graph
    sess = tf.Session()  # pass config=tf.ConfigProto(device_count={'GPU': 0}) to force CPU
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(logdir, sess.graph)

    checkfile = os.path.join(logdir, 'cnn_tree.ckpt')

    num_batches = len(trees) // BATCH_SIZE + (1 if len(trees) % BATCH_SIZE != 0
                                              else 0)
    for epoch in range(1, epochs + 1):
        for i, batch in enumerate(
                sampling.batch_samples(
                    sampling.gen_samples(trees, labels, embeddings,
                                         embed_lookup), BATCH_SIZE)):
            nodes, children, batch_labels = batch
            step = (epoch - 1) * num_batches + i  # global batch counter

            if not nodes:
                continue  # don't try to train on an empty batch

            _, summary, err, out = sess.run(
                [train_step, summaries, loss_node, out_node],
                feed_dict={
                    nodes_node: nodes,
                    children_node: children,
                    labels_node: batch_labels
                })

            print('Epoch:', epoch, 'Step:', step, 'Loss:', err, 'Max nodes:',
                  len(nodes[0]))

            writer.add_summary(summary, step)
            if step % CHECKPOINT_EVERY == 0:
                # save state so we can resume later
                saver.save(sess, checkfile, step)
                print('Checkpoint saved.')

    saver.save(sess, checkfile, step)

    # compute the training accuracy
    correct_labels = []
    predictions = []
    print('Computing training accuracy...')
    for batch in sampling.batch_samples(
            sampling.gen_samples(trees, labels, embeddings, embed_lookup), 1):
        nodes, children, batch_labels = batch
        output = sess.run([out_node],
                          feed_dict={
                              nodes_node: nodes,
                              children_node: children,
                          })
        correct_labels.append(np.argmax(batch_labels))
        predictions.append(np.argmax(output))

    target_names = list(labels)
    print('Accuracy:', accuracy_score(correct_labels, predictions))
    print(
        classification_report(correct_labels,
                              predictions,
                              target_names=target_names))
    print(confusion_matrix(correct_labels, predictions))
Code example #5
def train_model(logdir, infile, embedfile, conv_feature, learn_rate, batch_size, epochs=50, checkpoint_every=10000):
    """Train a classifier to label ASTs"""

    with open(infile, 'rb') as fh:
        trees, _, labels = pickle.load(fh)

    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
        num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats,
        conv_feature,
        len(labels)
    )

    out_node = network.out_layer(hidden_node)
    labels_node, loss_node = network.loss_layer(hidden_node, len(labels))

    optimizer = tf.train.GradientDescentOptimizer(learn_rate)
    train_step = optimizer.minimize(loss_node)

    tf.summary.scalar('loss', loss_node)

    ### init the graph
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(logdir, sess.graph)

    checkfile = os.path.join(logdir, 'cnn_tree.ckpt')

    num_batches = len(trees) // batch_size + (1 if len(trees) % batch_size != 0 else 0)
    for epoch in range(1, epochs + 1):
        max_step = 0
        for i, batch in enumerate(sampling.batch_samples(
                sampling.gen_samples(trees, labels, embeddings, embed_lookup), batch_size
        )):
            nodes, children, batch_labels = batch
            step = ((epoch - 1) * num_batches + i)
            max_step = max(step, max_step)
            if not nodes:
                continue  # don't try to train on an empty batch

            _, summary, err, out = sess.run(
                [train_step, summaries, loss_node, out_node],
                feed_dict={
                    nodes_node: nodes,
                    children_node: children,
                    labels_node: batch_labels
                }
            )
            if i % 200 == 0:
                print('Epoch:', epoch,
                      'Step:', step,
                      'Loss:', err,
                      'Max nodes:', len(nodes[0])
                      )
                sys.stdout.flush()

            if step % checkpoint_every == 0:
                # save state so we can resume later
                writer.add_summary(summary, step)
                saver.save(sess, checkfile, step)

    saver.save(sess, checkfile, step)

    # compute the training accuracy
    correct_labels = []
    predictions = []
    print('Computing training accuracy...')
    for batch in sampling.batch_samples(
            sampling.gen_samples(trees, labels, embeddings, embed_lookup), 1
    ):
        nodes, children, batch_labels = batch
        output = sess.run([out_node],
                          feed_dict={
                              nodes_node: nodes,
                              children_node: children,
                          }
                          )
        correct_labels.append(np.argmax(batch_labels))
        predictions.append(np.argmax(output))

    target_names = list(labels)
    print('Accuracy:', accuracy_score(correct_labels, predictions))
    print(classification_report(correct_labels, predictions, target_names=target_names))
    print(confusion_matrix(correct_labels, predictions))
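
A hypothetical invocation of this variant; the hyperparameter values are illustrative, and conv_feature is the convolution feature size forwarded to network.init_net:

train_model('./logs', './data/trees.pkl', './data/embeddings.pkl',
            conv_feature=100, learn_rate=0.01, batch_size=32)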