# Common imports assumed by the functions below; `network`, `sampling`,
# `scale_attention_score_by_group`, `RAdamOptimizer`, and the module-level
# constants (LEARN_RATE, BATCH_SIZE, EPOCHS, CHECKPOINT_EVERY, csv_log) are
# project-local and assumed to be defined elsewhere in the repository.
import os
import operator
import pickle
import random

import numpy as np
import progressbar
import tensorflow as tf
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


def test_model(test_trees, labels, embeddings, embedding_lookup, opt):
    """Evaluate a trained TreeCaps classifier on the test trees."""
    logdir = opt.model_path
    batch_size = opt.train_batch_size
    epochs = opt.niter
    num_feats = len(embeddings[0])
    random.shuffle(test_trees)

    # build the inputs and outputs of the network
    nodes_node, children_node, codecaps_node = network.init_net_treecaps(
        num_feats, len(labels))
    out_node = network.out_layer(codecaps_node)
    labels_node, loss_node = network.loss_layer(codecaps_node, len(labels))

    # mirror the training graph (including optimizer slot variables) so the
    # checkpoint restores cleanly; the train op itself is never run here
    optimizer = RAdamOptimizer(opt.lr)
    train_step = optimizer.minimize(loss_node)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restoring model from checkpoint")
            saver.restore(sess, ckpt.model_checkpoint_path)
            for i, var in enumerate(saver._var_list):
                print('Var {}: {}'.format(i, var))

    checkfile = os.path.join(logdir, 'tree_network.ckpt')

    correct_labels = []
    predictions = []
    print('Computing test accuracy...')
    for batch in sampling.batch_samples(
            sampling.gen_samples(test_trees, labels, embeddings, embedding_lookup), 1):
        nodes, children, batch_labels = batch
        output = sess.run([out_node],
                          feed_dict={nodes_node: nodes, children_node: children})
        correct_labels.append(np.argmax(batch_labels))
        predictions.append(np.argmax(output))

    target_names = list(labels)
    print(classification_report(correct_labels, predictions, target_names=target_names))
    print(confusion_matrix(correct_labels, predictions))
    print('*' * 50)
    print('Accuracy:', accuracy_score(correct_labels, predictions))
    print('*' * 50)
def predict(sess, out_node, attention_score_node, nodes_node, children_node,
            pkl_path, pb_path, test_trees, labels, node_ids, node_types,
            embeddings, embedding_lookup):
    """Predict the class of a tree and return its scaled per-node attention map."""
    for batch in sampling.batch_samples(
            sampling.gen_samples(test_trees, labels, embeddings, embedding_lookup), 1):
        nodes, children, batch_labels = batch
        output, attention_score = sess.run(
            [out_node, attention_score_node],
            feed_dict={nodes_node: nodes, children_node: children})

        splits = pkl_path.split(".")  # unused here

        confidence_score = output[0]
        actual = str(np.argmax(batch_labels) + 1)
        predicted = str(np.argmax(output) + 1)
        print("Probability : " + str(confidence_score))
        print("Actual class : " + actual)
        print("Predicted class : " + predicted)

        # flatten the per-node attention scores into a 1-D vector
        max_node = len(nodes[0])
        attention_score = np.reshape(attention_score, (max_node,))

        attention_score_map = {}
        for i, score in enumerate(attention_score):
            attention_score_map[str(node_ids[i])] = float(score)

        # sort node ids by descending attention score
        attention_score_sorted = sorted(attention_score_map.items(),
                                        key=operator.itemgetter(1),
                                        reverse=True)

        node_ids = [node_id for node_id, _ in attention_score_sorted]
        attention_score = [score for _, score in attention_score_sorted]

        attention_score_scaled = scale_attention_score_by_group(attention_score)

        attention_score_scaled_map = {}
        for i, score in enumerate(attention_score_scaled):
            attention_score_scaled_map[str(node_ids[i])] = float(score)

        # note: returns after the first batch
        return attention_score_scaled_map
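# Sketch of driving `predict` end to end. This is illustrative and not part of
# the original code: the graph tensors are assumed to come from the same
# network.init_net call used at training time (so the checkpoint restores
# cleanly), and the 'sample.pkl'/'model.pb' paths are hypothetical placeholders.
def _example_predict_usage(logdir, out_node, attention_score_node,
                           nodes_node, children_node, test_trees, labels,
                           node_ids, node_types, embeddings, embedding_lookup):
    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(logdir))
        scores = predict(sess, out_node, attention_score_node, nodes_node,
                         children_node, 'sample.pkl', 'model.pb', test_trees,
                         labels, node_ids, node_types, embeddings,
                         embedding_lookup)
        # the returned map is ordered by descending scaled attention score,
        # so the first entries are the AST nodes the model attended to most
        for node_id, score in list(scores.items())[:10]:
            print(node_id, score)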
def test_model(logdir, infile, embedfile):
    """Test a classifier to label ASTs"""
    with open(infile, 'rb') as fh:
        _, trees, labels = pickle.load(fh)

    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(num_feats, len(labels))
    out_node = network.out_layer(hidden_node)

    # init the graph
    sess = tf.Session()  # config=tf.ConfigProto(device_count={'GPU': 0})
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError('Checkpoint not found.')

    # make predictions from the input
    correct_labels = []
    predictions = []
    step = 0
    for batch in sampling.batch_samples(
            sampling.gen_samples(trees, labels, embeddings, embed_lookup), 1):
        nodes, children, batch_labels = batch
        output = sess.run([out_node],
                          feed_dict={nodes_node: nodes, children_node: children})
        correct_labels.append(np.argmax(batch_labels))
        predictions.append(np.argmax(output))
        step += 1
        print(step, '/', len(trees))

    target_names = list(labels)
    print('Accuracy:', accuracy_score(correct_labels, predictions))
    print(classification_report(correct_labels, predictions, target_names=target_names))
    print(confusion_matrix(correct_labels, predictions))
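# A minimal, hypothetical driver for the TBCNN test entry point above. The
# paths are placeholders: `infile` is expected to unpickle to (_, trees,
# labels) and `embedfile` to (embeddings, embed_lookup), matching the loads
# inside test_model.
def _run_tbcnn_test_example():
    test_model(logdir='logs/tbcnn',
               infile='data/algorithm_trees.pkl',
               embedfile='data/pretrained_vectors.pkl')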
def train_model(train_trees, test_trees, val_trees, labels, embeddings, embedding_lookup, opt):
    """Train a classifier to label ASTs"""
    logdir = opt.model_path
    batch_size = opt.train_batch_size
    epochs = opt.niter
    num_feats = len(embeddings[0])

    random.shuffle(train_trees)
    random.shuffle(val_trees)
    random.shuffle(test_trees)

    # meta_file = os.path.join(logdir, "cnn_tree.ckpt.meta")
    # if os.path.exists(meta_file):
    #     saver = tf.train.import_meta_graph(meta_file)
    #     saver.restore(sess, tf.train.latest_checkpoint('./'))
    #     graph = tf.get_default_graph()
    #     nodes_node = graph.get_tensor_by_name("tree:0")
    #     children_node = graph.get_tensor_by_name("children:0")
    #     hidden_node = graph.get_tensor_by_name("hidden_node:0")
    #     attention_score_node = graph.get_tensor_by_name("hidden_node:0")

    nodes_node, children_node, hidden_node, attention_score_node = network.init_net(
        num_feats, len(labels), opt.aggregation)
    hidden_node = tf.identity(hidden_node, name="hidden_node")
    out_node = network.out_layer(hidden_node)
    labels_node, loss_node = network.loss_layer(hidden_node, len(labels))

    optimizer = tf.train.AdamOptimizer(LEARN_RATE)
    train_step = optimizer.minimize(loss_node)

    # init the graph
    sess = tf.Session()  # config=tf.ConfigProto(device_count={'GPU': 0})
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Continue training with old model")
            saver.restore(sess, ckpt.model_checkpoint_path)
            for i, var in enumerate(saver._var_list):
                print('Var {}: {}'.format(i, var))

    checkfile = os.path.join(logdir, 'cnn_tree.ckpt')

    if opt.training:
        print("Begin training..........")
        num_batches = len(train_trees) // batch_size + (
            1 if len(train_trees) % batch_size != 0 else 0)
        for epoch in range(1, epochs + 1):
            for i, batch in enumerate(
                    sampling.batch_samples(
                        sampling.gen_samples(train_trees, labels, embeddings,
                                             embedding_lookup), batch_size)):
                nodes, children, batch_labels = batch
                step = (epoch - 1) * num_batches + i * batch_size
                if not nodes:
                    continue  # don't try to train on an empty batch
                _, err, out = sess.run(
                    [train_step, loss_node, out_node],
                    feed_dict={nodes_node: nodes,
                               children_node: children,
                               labels_node: batch_labels})
                print('Epoch:', epoch, 'Step:', step, 'Loss:', err,
                      'Max nodes:', len(nodes[0]))
                if step % CHECKPOINT_EVERY == 0:
                    # save state so we can resume later
                    saver.save(sess, checkfile)
                    print('Checkpoint saved, epoch: ' + str(epoch) +
                          ', step: ' + str(step) + ', loss: ' + str(err) + '.')

            # evaluate on the validation set after each epoch
            correct_labels = []
            predictions = []
            for batch in sampling.batch_samples(
                    sampling.gen_samples(val_trees, labels, embeddings,
                                         embedding_lookup), 1):
                nodes, children, batch_labels = batch
                output = sess.run([out_node],
                                  feed_dict={nodes_node: nodes, children_node: children})
                correct_labels.append(np.argmax(batch_labels))
                predictions.append(np.argmax(output))
            target_names = list(labels)
            print('Accuracy:', accuracy_score(correct_labels, predictions))
            print(classification_report(correct_labels, predictions,
                                        target_names=target_names))
            print(confusion_matrix(correct_labels, predictions))

        print("Finished all epochs, storing the whole model..........")
        saver.save(sess, checkfile)
def train_model(logdir, infile, embedfile, training="True", testing="False"):
    """Train an autoencoder over ASTs and optionally dump reconstructed vectors"""
    epochs = 100
    print("Loading trees...")
    print("Training = " + training)
    print("Testing = " + testing)
    with open(infile, 'rb') as fh:
        trees, test_trees, labels = pickle.load(fh)
        random.shuffle(trees)

    print("Loading embeddings....")
    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    print("Len embeddings : " + str(len(embeddings[0])))
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network (autoencoder variant)
    nodes_node, children_node, loss_node, pooling_node, x_reconstructed = \
        network.init_ae_net(30, 30, 15)

    optimizer = tf.train.AdamOptimizer(LEARN_RATE)
    train_step = optimizer.minimize(loss_node)

    # init the graph
    sess = tf.Session()  # config=tf.ConfigProto(device_count={'GPU': 0})
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Continue training with old model")
            saver.restore(sess, ckpt.model_checkpoint_path)

    checkfile = os.path.join(logdir, 'cnn_tree.ckpt')

    if training == "True":
        print("Begin training..........")
        num_batches = len(trees) // BATCH_SIZE + (
            1 if len(trees) % BATCH_SIZE != 0 else 0)
        for epoch in range(1, epochs + 1):
            for i, batch in enumerate(
                    sampling.batch_samples(
                        sampling.gen_samples(trees, labels, embeddings,
                                             embed_lookup), BATCH_SIZE)):
                nodes, children, _, _ = batch
                step = (epoch - 1) * num_batches + i * BATCH_SIZE
                if not nodes:
                    continue  # don't try to train on an empty batch
                _, err, pooling_vector, x_reconstructed_vector = sess.run(
                    [train_step, loss_node, pooling_node, x_reconstructed],
                    feed_dict={nodes_node: nodes, children_node: children})
                print("---------------------------------------------------------")
                print(x_reconstructed_vector)
                print('Epoch:', epoch, 'Step:', step, 'Loss:', err)
                if step % CHECKPOINT_EVERY == 0:
                    # save state so we can resume later
                    saver.save(sess, checkfile, step)
                    print('Checkpoint saved, epoch: ' + str(epoch) +
                          ', step: ' + str(step) + ', loss: ' + str(err) + '.')
        saver.save(sess, checkfile, step)

    if testing == "True":
        # run the training and test trees through the model and dump the
        # reconstructed vectors, each prefixed with its label
        vectors = []
        trees.extend(test_trees)
        for batch in sampling.batch_samples(
                sampling.gen_samples(trees, labels, embeddings, embed_lookup), 1):
            nodes, children, _, batch_labels = batch
            x_reconstructed_vector = sess.run(
                [x_reconstructed],
                feed_dict={nodes_node: nodes, children_node: children})
            print("---------------------------------------")
            print(x_reconstructed_vector)
            vector_str = ""
            for v in x_reconstructed_vector[0][0]:
                print(str(v))
                vector_str = vector_str + " " + str(v)
            print(vector_str)
            vectors.append(str(batch_labels[0][0]) + vector_str)

        # disambiguate duplicate labels with a per-class running index
        algos = ["bfs", "dfs", "bubblesort", "quicksort", "mergesort",
                 "heap", "linkedlist", "queue", "stack", "knapsack"]
        vectors_with_index = []
        for algo in algos:
            index = 0
            for ele in vectors:
                splits = ele.split(" ")
                label = splits[0]
                vector = splits[1:]
                if label == algo:
                    label = label + "_" + str(index)
                    index = index + 1
                    vectors_with_index.append(label + " " + " ".join(vector))

        with open("./test_vectors/java_vectors_30D_15_ae_train_test_trees.txt",
                  "a") as f:
            for ele in vectors_with_index:
                f.write(ele)
                f.write("\n")
def train_model(logdir, infile, embedfile, epochs=EPOCHS, training="True", testing="True"):
    """Train a classifier to label ASTs"""
    print("Loading trees...")
    with open(infile, 'rb') as fh:
        trees, test_trees, labels = pickle.load(fh)
        random.shuffle(trees)
    print(labels)

    print("Loading embeddings....")
    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network
    nodes_node, children_node, hidden_node = network.init_net(num_feats, len(labels))
    out_node = network.out_layer(hidden_node)
    labels_node, loss_node = network.loss_layer(hidden_node, len(labels))

    optimizer = tf.train.AdamOptimizer(LEARN_RATE)
    train_step = optimizer.minimize(loss_node)

    tf.summary.scalar('loss', loss_node)

    # init the graph
    sess = tf.Session()  # config=tf.ConfigProto(device_count={'GPU': 0})
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(logdir, sess.graph)
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Continue training with old model")
            saver.restore(sess, ckpt.model_checkpoint_path)

    checkfile = os.path.join(logdir, 'cnn_tree.ckpt')

    if training == "True":
        print("Begin training..........")
        num_batches = len(trees) // BATCH_SIZE + (
            1 if len(trees) % BATCH_SIZE != 0 else 0)
        for epoch in range(1, epochs + 1):
            for i, batch in enumerate(
                    sampling.batch_samples(
                        sampling.gen_samples(trees, labels, embeddings,
                                             embed_lookup), BATCH_SIZE)):
                nodes, children, batch_labels = batch
                step = (epoch - 1) * num_batches + i * BATCH_SIZE
                if not nodes:
                    continue  # don't try to train on an empty batch
                _, summary, err, out = sess.run(
                    [train_step, summaries, loss_node, out_node],
                    feed_dict={nodes_node: nodes,
                               children_node: children,
                               labels_node: batch_labels})
                print('Epoch:', epoch, 'Step:', step, 'Loss:', err,
                      'Max nodes:', len(nodes[0]))
                writer.add_summary(summary, step)
                if step % CHECKPOINT_EVERY == 0:
                    # save state so we can resume later
                    saver.save(sess, checkfile, step)
                    print('Checkpoint saved, epoch: ' + str(epoch) +
                          ', step: ' + str(step) + ', loss: ' + str(err) + '.')
        saver.save(sess, checkfile, step)

    # compute the test accuracy
    if testing == "True":
        correct_labels = []
        predictions = []
        print('Computing test accuracy...')
        for batch in sampling.batch_samples(
                sampling.gen_samples(test_trees, labels, embeddings,
                                     embed_lookup), 1):
            nodes, children, batch_labels = batch
            output = sess.run([out_node],
                              feed_dict={nodes_node: nodes, children_node: children})
            correct_labels.append(np.argmax(batch_labels))
            predictions.append(np.argmax(output))

        target_names = list(labels)
        print('Accuracy:', accuracy_score(correct_labels, predictions))
        print(classification_report(correct_labels, predictions,
                                    target_names=target_names))
        print(confusion_matrix(correct_labels, predictions))
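# Usage note: the training loop above streams the scalar loss to `logdir`
# through tf.summary.FileWriter, so the loss curve can be inspected with the
# standard TensorBoard CLI pointed at the same directory:
#
#   tensorboard --logdir <logdir>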
def test_model(test_trees, labels, embeddings, embedding_lookup, opt):
    """Evaluate a trained attention-TBCNN classifier on the test trees."""
    logdir = opt.model_path
    node_embedding_size = len(embeddings[0])
    random.shuffle(test_trees)

    initializer = tf.contrib.layers.xavier_initializer()
    weights = {
        "w_t": tf.Variable(initializer([node_embedding_size, opt.feature_size]), name="w_t"),
        "w_l": tf.Variable(initializer([node_embedding_size, opt.feature_size]), name="w_l"),
        "w_r": tf.Variable(initializer([node_embedding_size, opt.feature_size]), name="w_r"),
        "w_attention": tf.Variable(initializer([opt.feature_size, 1]), name="w_attention"),
    }
    biases = {
        "b_conv": tf.Variable(initializer([opt.feature_size]), name="b_conv"),
    }

    nodes_node, children_node, hidden_node, attention_score_node = network.init_net(
        node_embedding_size, len(labels), opt.feature_size, weights, biases,
        opt.aggregation, opt.distributed_function)
    out_node = network.out_layer(hidden_node)
    labels_node, loss_node = network.loss_layer(hidden_node, len(labels))

    # mirror the training graph (including optimizer slot variables) so the
    # checkpoint restores cleanly; the train op itself is never run here
    optimizer = tf.train.AdamOptimizer(LEARN_RATE)
    train_step = optimizer.minimize(loss_node)

    saver = tf.train.Saver(save_relative_paths=True, max_to_keep=5)
    # Initialize the variables (i.e. assign their default value)
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restoring model from checkpoint")
            print("Checkpoint path : " + str(ckpt.model_checkpoint_path))
            saver.restore(sess, ckpt.model_checkpoint_path)

        correct_labels = []
        predictions = []
        print('Computing test accuracy...')
        for batch in sampling.batch_samples(
                sampling.gen_samples(test_trees, labels, embeddings,
                                     embedding_lookup), 1):
            nodes, children, batch_labels = batch
            output = sess.run([out_node],
                              feed_dict={nodes_node: nodes, children_node: children})
            correct_labels.append(np.argmax(batch_labels))
            predictions.append(np.argmax(output))

        target_names = list(labels)
        print('Accuracy:', accuracy_score(correct_labels, predictions))
        print(classification_report(correct_labels, predictions,
                                    target_names=target_names))
        print(confusion_matrix(correct_labels, predictions))
def train_model(train_trees, val_trees, labels, embeddings, embedding_lookup, opt):
    """Train a classifier to label ASTs"""
    logdir = opt.model_path
    batch_size = opt.train_batch_size
    epochs = opt.niter
    node_embedding_size = len(embeddings[0])

    random.shuffle(train_trees)
    random.shuffle(val_trees)

    checkfile = os.path.join(logdir, 'cnn_tree.ckpt')
    ckpt = tf.train.get_checkpoint_state(logdir)

    initializer = tf.contrib.layers.xavier_initializer()
    weights = {
        "w_t": tf.Variable(initializer([node_embedding_size, opt.feature_size]), name="w_t"),
        "w_l": tf.Variable(initializer([node_embedding_size, opt.feature_size]), name="w_l"),
        "w_r": tf.Variable(initializer([node_embedding_size, opt.feature_size]), name="w_r"),
        "w_attention": tf.Variable(initializer([opt.feature_size, 1]), name="w_attention"),
    }
    biases = {
        "b_conv": tf.Variable(initializer([opt.feature_size]), name="b_conv"),
    }

    nodes_node, children_node, hidden_node, attention_score_node = network.init_net(
        node_embedding_size, len(labels), opt.feature_size, weights, biases,
        opt.aggregation, opt.distributed_function)
    out_node = network.out_layer(hidden_node)
    labels_node, loss_node = network.loss_layer(hidden_node, len(labels))

    optimizer = tf.train.AdamOptimizer(LEARN_RATE)
    train_step = optimizer.minimize(loss_node)

    saver = tf.train.Saver(save_relative_paths=True, max_to_keep=5)
    # Initialize the variables (i.e. assign their default value)
    init = tf.global_variables_initializer()

    if opt.training:
        print("Begin training..........")
        with tf.Session() as sess:
            sess.run(init)
            if ckpt and ckpt.model_checkpoint_path:
                print("Continue training with old model")
                print("Checkpoint path : " + str(ckpt.model_checkpoint_path))
                saver.restore(sess, ckpt.model_checkpoint_path)
                for i, var in enumerate(saver._var_list):
                    print('Var {}: {}'.format(i, var))

            num_batches = len(train_trees) // batch_size + (
                1 if len(train_trees) % batch_size != 0 else 0)
            for epoch in range(1, epochs + 1):
                for i, batch in enumerate(
                        sampling.batch_samples(
                            sampling.gen_samples(train_trees, labels, embeddings,
                                                 embedding_lookup), batch_size)):
                    nodes, children, batch_labels = batch
                    step = (epoch - 1) * num_batches + i * batch_size
                    if not nodes:
                        continue  # don't try to train on an empty batch
                    _, err, out = sess.run(
                        [train_step, loss_node, out_node],
                        feed_dict={nodes_node: nodes,
                                   children_node: children,
                                   labels_node: batch_labels})
                    print('Epoch:', epoch, 'Step:', step, 'Loss:', err,
                          'Max nodes:', len(nodes[0]))
                    if step % CHECKPOINT_EVERY == 0:
                        # save state so we can resume later
                        saver.save(sess, checkfile)
                        print('Checkpoint saved, epoch: ' + str(epoch) +
                              ', step: ' + str(step) + ', loss: ' + str(err) + '.')

                # evaluate on the validation set after each epoch
                correct_labels = []
                predictions = []
                for batch in sampling.batch_samples(
                        sampling.gen_samples(val_trees, labels, embeddings,
                                             embedding_lookup), 1):
                    nodes, children, batch_labels = batch
                    output = sess.run([out_node],
                                      feed_dict={nodes_node: nodes, children_node: children})
                    correct_labels.append(np.argmax(batch_labels))
                    predictions.append(np.argmax(output))
                target_names = list(labels)
                print('Accuracy:', accuracy_score(correct_labels, predictions))
                print(classification_report(correct_labels, predictions,
                                            target_names=target_names))
                print(confusion_matrix(correct_labels, predictions))

            print("Finished all epochs, storing the whole model..........")
            saver.save(sess, checkfile)
def train_model(train_trees, val_trees, labels, embeddings, embedding_lookup, opt):
    """Train a TreeCaps classifier, keeping the best-scoring checkpoint."""
    max_acc = 0.0
    logdir = opt.model_path
    batch_size = opt.train_batch_size
    epochs = opt.niter
    num_feats = len(embeddings[0])
    random.shuffle(train_trees)

    nodes_node, children_node, codecaps_node = network.init_net_treecaps(
        num_feats, len(labels))
    codecaps_node = tf.identity(codecaps_node, name="codecaps_node")
    out_node = network.out_layer(codecaps_node)
    labels_node, loss_node = network.loss_layer(codecaps_node, len(labels))

    optimizer = RAdamOptimizer(opt.lr)
    train_step = optimizer.minimize(loss_node)

    # init the graph
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Continue training with old model")
            saver.restore(sess, ckpt.model_checkpoint_path)
            for i, var in enumerate(saver._var_list):
                print('Var {}: {}'.format(i, var))

    checkfile = os.path.join(logdir, 'tree_network.ckpt')

    print("Begin training..........")
    num_batches = len(train_trees) // batch_size + (
        1 if len(train_trees) % batch_size != 0 else 0)
    for epoch in range(1, epochs + 1):
        bar = progressbar.ProgressBar(
            maxval=len(train_trees),
            widgets=[progressbar.Bar('=', '[', ']'), ' ', progressbar.Percentage()])
        bar.start()
        for i, batch in enumerate(
                sampling.batch_samples(
                    sampling.gen_samples(train_trees, labels, embeddings,
                                         embedding_lookup), batch_size)):
            nodes, children, batch_labels = batch
            if not nodes:
                continue
            _, err, out = sess.run(
                [train_step, loss_node, out_node],
                feed_dict={nodes_node: nodes,
                           children_node: children,
                           labels_node: batch_labels})
            bar.update(i + 1)
        bar.finish()

        # evaluate on the validation set after each epoch
        correct_labels = []
        predictions = []
        logits = []
        for batch in sampling.batch_samples(
                sampling.gen_samples(val_trees, labels, embeddings,
                                     embedding_lookup), 1):
            nodes, children, batch_labels = batch
            output = sess.run([out_node],
                              feed_dict={nodes_node: nodes, children_node: children})
            correct_labels.append(np.argmax(batch_labels))
            predictions.append(np.argmax(output))
            logits.append(output)

        target_names = list(labels)
        acc = accuracy_score(correct_labels, predictions)
        # keep only the best-performing checkpoint
        if acc > max_acc:
            max_acc = acc
            saver.save(sess, checkfile)
            np.save(opt.model_path + '/logits', np.array(logits))
            np.save(opt.model_path + '/correct', np.array(correct_labels))
        print('Epoch', str(epoch), 'Accuracy:', acc, 'Max Acc: ', max_acc)
        # csv_log is assumed to be an open, module-level log file
        csv_log.write(str(epoch) + ',' + str(acc) + ',' + str(max_acc) + '\n')

    print("Finished all epochs, storing the whole model..........")
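# Sketch: recomputing the best validation accuracy offline from the arrays the
# loop above dumps with np.save. The file names follow the code above; the
# squeeze assumes the saved logits have shape [N, 1, 1, num_labels] (a list of
# single-sample sess.run outputs), which matches how `logits` is built.
def _replay_validation_accuracy(model_path):
    logits = np.load(model_path + '/logits.npy')
    correct = np.load(model_path + '/correct.npy')
    predictions = np.argmax(np.squeeze(logits), axis=-1)  # [N, num_labels] -> [N]
    return accuracy_score(correct, predictions)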
def train_model(logdir, infile, embedfile, training="True", testing="False"):
    """Train a variational autoencoder over ASTs"""
    epochs = 100
    print("Loading trees...")
    print("Training = " + training)
    print("Testing = " + testing)
    with open(infile, 'rb') as fh:
        trees, test_trees, labels = pickle.load(fh)
        random.shuffle(trees)

    print("Loading embeddings....")
    with open(embedfile, 'rb') as fh:
        embeddings, embed_lookup = pickle.load(fh)
    print("Len embeddings : " + str(len(embeddings[0])))
    num_feats = len(embeddings[0])

    # build the inputs and outputs of the network (VAE variant)
    nodes_node, children_node, loss_node, pooling_node, x_reconstructed = \
        network.init_vae_net(30, 100)

    optimizer = tf.train.AdamOptimizer(LEARN_RATE)
    train_step = optimizer.minimize(loss_node)

    # init the graph
    sess = tf.Session()  # config=tf.ConfigProto(device_count={'GPU': 0})
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Continue training with old model")
            saver.restore(sess, ckpt.model_checkpoint_path)

    checkfile = os.path.join(logdir, 'cnn_tree.ckpt')

    if training == "True":
        print("Begin training..........")
        num_batches = len(trees) // BATCH_SIZE + (
            1 if len(trees) % BATCH_SIZE != 0 else 0)
        for epoch in range(1, epochs + 1):
            for i, batch in enumerate(
                    sampling.batch_samples(
                        sampling.gen_samples(trees, labels, embeddings,
                                             embed_lookup), BATCH_SIZE)):
                nodes, children, _, _ = batch
                step = (epoch - 1) * num_batches + i * BATCH_SIZE
                if not nodes:
                    continue  # don't try to train on an empty batch
                _, err, pooling_vector, x_reconstructed_vector = sess.run(
                    [train_step, loss_node, pooling_node, x_reconstructed],
                    feed_dict={nodes_node: nodes, children_node: children})
                print("---------------------------------------------------------")
                print(x_reconstructed_vector)
                print('Epoch:', epoch, 'Step:', step, 'Loss:', err)
                if step % CHECKPOINT_EVERY == 0:
                    # save state so we can resume later
                    saver.save(sess, checkfile, step)
                    print('Checkpoint saved, epoch: ' + str(epoch) +
                          ', step: ' + str(step) + ', loss: ' + str(err) + '.')
        saver.save(sess, checkfile, step)

    if testing == "True":
        for batch in sampling.batch_samples(
                sampling.gen_samples(test_trees, labels, embeddings,
                                     embed_lookup), 1):
            nodes, children, _, _ = batch
            x_reconstructed_vector = sess.run(
                [x_reconstructed],
                feed_dict={nodes_node: nodes, children_node: children})
            print("---------------------------------------")
            print(x_reconstructed_vector)
def train_model(train_trees, val_trees, labels, embedding_lookup, opt):
    """Train a TreeCaps classifier with periodic in-epoch validation."""
    max_acc = 0.0
    logdir = opt.model_path
    batch_size = opt.train_batch_size
    epochs = opt.niter
    random.shuffle(train_trees)

    nodes_node, children_node, codecaps_node = network.init_net_treecaps(
        50, embedding_lookup, len(labels))
    codecaps_node = tf.identity(codecaps_node, name="codecaps_node")
    out_node = network.out_layer(codecaps_node)
    labels_node, loss_node = network.loss_layer(codecaps_node, len(labels))

    optimizer = RAdamOptimizer(opt.lr)
    # `train_point` avoids clashing with the `train_step` loop index below
    train_point = optimizer.minimize(loss_node)

    # init the graph
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    with tf.name_scope('saver'):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Continue training with old model")
            saver.restore(sess, ckpt.model_checkpoint_path)
            for i, var in enumerate(saver._var_list):
                print('Var {}: {}'.format(i, var))

    checkfile = os.path.join(logdir, 'tree_network.ckpt')

    print("Begin training..........")
    num_batches = len(train_trees) // batch_size + (
        1 if len(train_trees) % batch_size != 0 else 0)
    for epoch in range(1, epochs + 1):
        for train_step, train_batch in enumerate(
                sampling.batch_samples(
                    sampling.gen_samples(train_trees, labels), batch_size)):
            nodes, children, batch_labels = train_batch
            # step = (epoch - 1) * num_batches + train_step * batch_size
            if not nodes:
                continue
            _, err, out = sess.run(
                [train_point, loss_node, out_node],
                feed_dict={nodes_node: nodes,
                           children_node: children,
                           labels_node: batch_labels})
            print("Epoch : ", str(epoch), "Step : ", train_step,
                  "Loss : ", err, "Max Acc: ", max_acc)

            # periodically evaluate on the validation set
            if train_step % 1000 == 0 and train_step > 0:
                correct_labels = []
                predictions = []
                for test_batch in sampling.batch_samples(
                        sampling.gen_samples(val_trees, labels), batch_size):
                    print("---------------")
                    nodes, children, batch_labels = test_batch
                    print(batch_labels)
                    output = sess.run([out_node],
                                      feed_dict={nodes_node: nodes, children_node: children})
                    batch_correct_labels = np.argmax(batch_labels, axis=1)
                    batch_predictions = np.argmax(output[0], axis=1)
                    correct_labels.extend(batch_correct_labels)
                    predictions.extend(batch_predictions)
                    print(batch_correct_labels)
                    print(batch_predictions)

                acc = accuracy_score(correct_labels, predictions)
                # keep only the best-performing checkpoint
                if acc > max_acc:
                    max_acc = acc
                    saver.save(sess, checkfile)
                    print("Saved checkpoint....")
                print('Epoch', str(epoch), 'Accuracy:', acc, 'Max Acc: ', max_acc)
                # csv_log is assumed to be an open, module-level log file
                csv_log.write(str(epoch) + ',' + str(acc) + ',' + str(max_acc) + '\n')

    print("Finished all epochs, storing the whole model..........")