def predict_model(args, logdir, infile, embedfile):
    """Print the model's predicted label and score for every tree in a dataset.

    For each sample a tab-separated line is printed: the predicted label
    name, the score the network gave it, and a ``k/N`` progress field that
    is followed by the tree's ``meta['name']`` when present.

    Args:
        args: Unused; kept so existing call sites keep working.
        logdir: Directory holding the TensorFlow checkpoint to restore.
        infile: Pickle file unpacked as ``(_, trees, cv, labels)``.
        embedfile: Pickle file unpacked as ``(embeddings, embed_lookup)``.

    Raises:
        FileNotFoundError: If no checkpoint can be found in ``logdir``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(infile, 'rb') as fh:
        _, trees, _cv, labels = pickle.load(fh)
    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, len(labels))
    out_node = network.out_layer(hidden_node)

    # The context manager closes the session when prediction is finished.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        with tf.name_scope('saver'):
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(logdir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                # Raising a bare string is itself a TypeError in Python 3.
                raise FileNotFoundError('Checkpoint not found.')

        total = len(trees)
        samples = sampling.gen_samples(trees, labels, embeddings, embed_lookup)
        # Batch size 1, so `step` indexes back into `trees`.
        # assumes gen_samples yields exactly one sample per tree, in order — TODO confirm
        for step, batch in enumerate(sampling.batch_samples(samples, 1)):
            nodes, children, _batch_labels = batch
            output = sess.run([out_node], feed_dict={
                nodes_node: nodes,
                children_node: children,
            })

            pred_num = np.argmax(output)
            pred_label = labels[pred_num]
            pred_score = output[0][0][pred_num]

            pred_item = str(step + 1) + '/' + str(total)
            meta = trees[step].get('meta')
            if meta and 'name' in meta:
                pred_item += ' ' + meta['name']

            # The true label is intentionally not printed; the layout is unchanged.
            print(pred_label + ' \t' + str(pred_score) + ' \t' + pred_item)
def test_model(logdir, infile, embedfile):
    """Evaluate a trained classifier on a pickled tree set and print metrics.

    Prints a ``step / total`` progress line per sample, then the accuracy,
    a per-class classification report and the confusion matrix.

    Args:
        logdir: Directory holding the TensorFlow checkpoint to restore.
        infile: Pickle file unpacked as ``(_, trees, labels)``; the second
            element is the tree set evaluated here (presumably the test split).
        embedfile: Pickle file unpacked as ``(embeddings, embed_lookup)``.

    Raises:
        FileNotFoundError: If no checkpoint can be found in ``logdir``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(infile, 'rb') as fh:
        _, trees, labels = pickle.load(fh)
    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, len(labels))
    out_node = network.out_layer(hidden_node)

    correct_labels = []
    predictions = []

    # The context manager closes the session once predictions are collected.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        with tf.name_scope('saver'):
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(logdir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                # Raising a bare string is itself a TypeError in Python 3.
                raise FileNotFoundError('Checkpoint not found.')

        total = len(trees)
        samples = sampling.gen_samples(trees, labels, embeddings, embed_lookup)
        for step, batch in enumerate(sampling.batch_samples(samples, 1), start=1):
            nodes, children, batch_labels = batch
            output = sess.run([out_node], feed_dict={
                nodes_node: nodes,
                children_node: children,
            })
            correct_labels.append(np.argmax(batch_labels))
            predictions.append(np.argmax(output))
            print(step, '/', total)

    target_names = list(labels)
    # Pass every class index explicitly: otherwise sklearn infers the classes
    # from the data and rejects target_names when some class never occurs.
    label_ids = list(range(len(target_names)))
    print('Accuracy:', accuracy_score(correct_labels, predictions))
    print(classification_report(correct_labels, predictions,
                                labels=label_ids, target_names=target_names))
    print(confusion_matrix(correct_labels, predictions, labels=label_ids))
def classify_item(logdir, infile, embedfile, label_file):
    """Print the predicted label name for every tree in ``infile``.

    Args:
        logdir: Directory holding the TensorFlow checkpoint to restore.
        infile: Pickle file unpacked as ``(trees, _)``. Its second element is
            ignored, so ``label_file`` alone defines the label set.
        embedfile: Pickle file unpacked as ``(embeddings, embed_lookup)``.
        label_file: Pickle file holding the label names. Its length is passed
            to ``network.init_net``, and its order maps the argmax of the
            network output to a printed name.

    Raises:
        FileNotFoundError: If no checkpoint can be found in ``logdir``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(label_file, 'rb') as fh:
        labels = pickle.load(fh)
    with open(infile, 'rb') as fh:
        # This was previously unpacked as ``trees, labels``. That silently
        # replaced the labels read from label_file and made the argument dead.
        trees, _ = pickle.load(fh)
    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, len(labels))
    out_node = network.out_layer(hidden_node)

    # The context manager closes the session once classification is done.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        with tf.name_scope('saver'):
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(logdir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                # Raising a bare string is itself a TypeError in Python 3.
                raise FileNotFoundError('Checkpoint not found.')

        # NOTE(review): gen_samples also receives `labels`; presumably each
        # tree's own label must appear in it — confirm against sampling.py.
        samples = sampling.gen_samples(trees, labels, embeddings, embed_lookup)
        for batch in sampling.batch_samples(samples, 1):
            nodes, children, _batch_labels = batch
            output = sess.run([out_node], feed_dict={
                nodes_node: nodes,
                children_node: children,
            })
            print(labels[np.argmax(output)])
import pickle

import numpy as np
import tensorflow as tf
import classifier.tbcnn.network as network
import classifier.tbcnn.sampling as sampling
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score


def classify_item(logdir, infile, embedfile):
    """Classify every tree in ``infile``, printing and returning the label names.

    The previous body referenced an undefined ``labels`` and ended right after
    restoring the checkpoint. It now reads ``(trees, labels)`` from ``infile``
    (the same layout the label_file variant of this function unpacks) and runs
    each tree through the restored network.

    Args:
        logdir: Directory holding the TensorFlow checkpoint to restore.
        infile: Pickle file unpacked as ``(trees, labels)``.
        embedfile: Pickle file unpacked as ``(embeddings, embed_lookup)``.

    Returns:
        List of predicted label names, one per sample produced.

    Raises:
        FileNotFoundError: If no checkpoint can be found in ``logdir``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(infile, 'rb') as fh:
        trees, labels = pickle.load(fh)
    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, len(labels))
    out_node = network.out_layer(hidden_node)

    predicted = []
    # The context manager closes the session once classification is done.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        with tf.name_scope('saver'):
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(logdir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                # Raising a bare string is itself a TypeError in Python 3.
                raise FileNotFoundError('Checkpoint not found.')

        samples = sampling.gen_samples(trees, labels, embeddings, embed_lookup)
        for batch in sampling.batch_samples(samples, 1):
            nodes, children, _batch_labels = batch
            output = sess.run([out_node], feed_dict={
                nodes_node: nodes,
                children_node: children,
            })
            label = labels[np.argmax(output)]
            print(label)
            predicted.append(label)

    return predicted
def train_model(logdir, infile, embedfile, epochs=EPOCHS):
    """Train a classifier to label ASTs, then report training-set accuracy.

    Hyper-parameters come from the module-level constants LEARN_RATE,
    BATCH_SIZE and CHECKPOINT_EVERY. Loss summaries are written to
    ``logdir`` every batch. A checkpoint ``cnn_tree.ckpt`` is saved there
    whenever the global batch index is a multiple of CHECKPOINT_EVERY, and
    once more after the last epoch.

    Args:
        logdir: Directory for TensorBoard summaries and checkpoints.
        infile: Pickle file unpacked as ``(trees, _, labels)``.
        embedfile: Pickle file unpacked as ``(embeddings, embed_lookup)``.
        epochs: Number of passes over ``trees``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(infile, 'rb') as fh:
        trees, _, labels = pickle.load(fh)
    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, len(labels))
    out_node = network.out_layer(hidden_node)
    labels_node, loss_node = network.loss_layer(hidden_node, len(labels))

    optimizer = tf.train.AdamOptimizer(LEARN_RATE)
    train_step = optimizer.minimize(loss_node)
    tf.summary.scalar('loss', loss_node)

    # The context manager closes the session when training/evaluation ends.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        with tf.name_scope('saver'):
            saver = tf.train.Saver()
            summaries = tf.summary.merge_all()
            writer = tf.summary.FileWriter(logdir, sess.graph)

        checkfile = os.path.join(logdir, 'cnn_tree.ckpt')
        # Ceiling division: the final batch of an epoch may be partial.
        num_batches = (len(trees) + BATCH_SIZE - 1) // BATCH_SIZE

        step = 0  # keeps the final save valid even if no batch is produced
        for epoch in range(1, epochs + 1):
            samples = sampling.gen_samples(trees, labels, embeddings, embed_lookup)
            for i, batch in enumerate(sampling.batch_samples(samples, BATCH_SIZE)):
                nodes, children, batch_labels = batch
                # Global batch index. It used to be scaled by BATCH_SIZE for `i`
                # but not for the epoch offset, so steps went backwards between
                # epochs and summaries overlapped.
                step = (epoch - 1) * num_batches + i

                if not nodes:
                    continue  # don't try to train on an empty batch

                _, summary, err = sess.run(
                    [train_step, summaries, loss_node],
                    feed_dict={
                        nodes_node: nodes,
                        children_node: children,
                        labels_node: batch_labels,
                    })

                print('Epoch:', epoch, 'Step:', step, 'Loss:', err,
                      'Max nodes:', len(nodes[0]))
                writer.add_summary(summary, step)
                if step % CHECKPOINT_EVERY == 0:
                    # save state so we can resume later
                    saver.save(sess, checkfile, step)
                    print('Checkpoint saved.')

        saver.save(sess, checkfile, step)
        writer.close()  # flush pending summary events to disk

        # compute the training accuracy
        correct_labels = []
        predictions = []
        print('Computing training accuracy...')
        samples = sampling.gen_samples(trees, labels, embeddings, embed_lookup)
        for batch in sampling.batch_samples(samples, 1):
            nodes, children, batch_labels = batch
            output = sess.run([out_node], feed_dict={
                nodes_node: nodes,
                children_node: children,
            })
            correct_labels.append(np.argmax(batch_labels))
            predictions.append(np.argmax(output))

    target_names = list(labels)
    # Pass every class index explicitly so sklearn does not reject
    # target_names when some class never occurs.
    label_ids = list(range(len(target_names)))
    print('Accuracy:', accuracy_score(correct_labels, predictions))
    print(classification_report(correct_labels, predictions,
                                labels=label_ids, target_names=target_names))
    print(confusion_matrix(correct_labels, predictions, labels=label_ids))
def train_model(logdir, infile, embedfile, conv_feature, learn_rate,
                batch_size, epochs=50, checkpoint_every=10000):
    """Train a classifier to label ASTs, then report training-set accuracy.

    Args:
        logdir: Directory for TensorBoard summaries and checkpoints
            (``cnn_tree.ckpt``).
        infile: Pickle file unpacked as ``(trees, _, labels)``.
        embedfile: Pickle file unpacked as ``(embeddings, embed_lookup)``.
        conv_feature: Forwarded to ``network.init_net`` as its second argument.
        learn_rate: Learning rate for ``tf.train.GradientDescentOptimizer``.
        batch_size: Number of samples per training batch.
        epochs: Number of passes over ``trees``.
        checkpoint_every: A loss summary and a checkpoint are written whenever
            the global batch index is a multiple of this value. A final
            checkpoint is also saved after the last epoch.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(infile, 'rb') as fh:
        trees, _, labels = pickle.load(fh)
    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(
        num_feats, conv_feature, len(labels))
    out_node = network.out_layer(hidden_node)
    labels_node, loss_node = network.loss_layer(hidden_node, len(labels))

    optimizer = tf.train.GradientDescentOptimizer(learn_rate)
    train_step = optimizer.minimize(loss_node)
    tf.summary.scalar('loss', loss_node)

    # Allocate GPU memory on demand rather than reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # The context manager closes the session when training/evaluation ends.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        with tf.name_scope('saver'):
            saver = tf.train.Saver()
            summaries = tf.summary.merge_all()
            writer = tf.summary.FileWriter(logdir, sess.graph)

        checkfile = os.path.join(logdir, 'cnn_tree.ckpt')
        # Ceiling division: the final batch of an epoch may be partial.
        num_batches = (len(trees) + batch_size - 1) // batch_size

        step = 0  # keeps the final save valid even if no batch is produced
        for epoch in range(1, epochs + 1):
            samples = sampling.gen_samples(trees, labels, embeddings, embed_lookup)
            for i, batch in enumerate(sampling.batch_samples(samples, batch_size)):
                nodes, children, batch_labels = batch
                step = (epoch - 1) * num_batches + i  # global batch index

                if not nodes:
                    continue  # don't try to train on an empty batch

                _, summary, err = sess.run(
                    [train_step, summaries, loss_node],
                    feed_dict={
                        nodes_node: nodes,
                        children_node: children,
                        labels_node: batch_labels,
                    })

                if i % 200 == 0:
                    print('Epoch:', epoch, 'Step:', step, 'Loss:', err,
                          'Max nodes:', len(nodes[0]))
                    sys.stdout.flush()

                if step % checkpoint_every == 0:
                    # save state so we can resume later
                    writer.add_summary(summary, step)
                    saver.save(sess, checkfile, step)

        saver.save(sess, checkfile, step)
        writer.close()  # flush pending summary events to disk

        # compute the training accuracy
        correct_labels = []
        predictions = []
        print('Computing training accuracy...')
        samples = sampling.gen_samples(trees, labels, embeddings, embed_lookup)
        for batch in sampling.batch_samples(samples, 1):
            nodes, children, batch_labels = batch
            output = sess.run([out_node], feed_dict={
                nodes_node: nodes,
                children_node: children,
            })
            correct_labels.append(np.argmax(batch_labels))
            predictions.append(np.argmax(output))

    target_names = list(labels)
    # Pass every class index explicitly so sklearn does not reject
    # target_names when some class never occurs.
    label_ids = list(range(len(target_names)))
    print('Accuracy:', accuracy_score(correct_labels, predictions))
    print(classification_report(correct_labels, predictions,
                                labels=label_ids, target_names=target_names))
    print(confusion_matrix(correct_labels, predictions, labels=label_ids))