def DevLoss(text_length, dev_x, dev_y, dev_path, word_index_dict_size,
            dropout_keep_prob, embedding_dim, filter_sizes, num_filters,
            l2_reg_lambda, cpus):
    # The dev set is loaded here only to get its size, which BatchLoss needs
    # to split the input across devices. `dev_path` must be passed in; it is
    # not visible inside this function otherwise.
    _, y = data.LoadDevData(dev_path)
    x_size = len(y)
    # Reuse the variables built by the training towers so the dev tower
    # shares their weights instead of allocating a second copy.
    tf.get_variable_scope().reuse_variables()
    dev_loss, dev_acc = BatchLoss(text_length, dev_x, dev_y,
                                  word_index_dict_size, dropout_keep_prob,
                                  embedding_dim, filter_sizes, num_filters,
                                  x_size, l2_reg_lambda, cpus)
    return dev_loss, dev_acc
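# BatchLoss and MultiGPUCalcLossAccu are defined elsewhere in this project.
# For reference, minimal sketches of what they are assumed to compute follow:
# a Kim-style TextCNN tower returning (loss, accuracy), and a data-parallel
# wrapper that averages per-GPU gradients. The names, layer shapes, and the
# 9-class output size below are illustrative assumptions inferred from the
# call sites in this file, not the project's actual implementations.
def BatchLossSketch(text_length, x, y, vocab_size, dropout_keep_prob,
                    embedding_dim, filter_sizes, num_filters, l2_reg_lambda):
    # Embedding lookup: [batch, text_length] -> [batch, text_length, dim, 1].
    embedding = tf.get_variable("embedding", [vocab_size, embedding_dim])
    embedded = tf.expand_dims(tf.nn.embedding_lookup(embedding, x), -1)
    pooled = []
    for filter_size in filter_sizes:
        with tf.variable_scope("conv-maxpool-%d" % filter_size):
            conv_w = tf.get_variable(
                "w", [filter_size, embedding_dim, 1, num_filters])
            conv_b = tf.get_variable("b", [num_filters])
            conv = tf.nn.conv2d(embedded, conv_w, strides=[1, 1, 1, 1],
                                padding="VALID")
            h = tf.nn.relu(tf.nn.bias_add(conv, conv_b))
            # Max over time: one feature per filter for the whole text.
            pooled.append(tf.nn.max_pool(
                h, ksize=[1, text_length - filter_size + 1, 1, 1],
                strides=[1, 1, 1, 1], padding="VALID"))
    num_filters_total = num_filters * len(filter_sizes)
    h_pool = tf.reshape(tf.concat(pooled, 3), [-1, num_filters_total])
    h_drop = tf.nn.dropout(h_pool, dropout_keep_prob)
    with tf.variable_scope("output"):
        out_w = tf.get_variable("w", [num_filters_total, 9])
        out_b = tf.get_variable("b", [9])
        scores = tf.nn.xw_plus_b(h_drop, out_w, out_b)
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
    loss += l2_reg_lambda * (tf.nn.l2_loss(out_w) + tf.nn.l2_loss(out_b))
    correct = tf.equal(tf.argmax(scores, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    return loss, accuracy

def MultiGPULossSketch(text_length, gpus, input_x, input_y, batch_size,
                       vocab_size, dropout_keep_prob, embedding_dim,
                       filter_sizes, num_filters, l2_reg_lambda, opt,
                       global_step):
    per_gpu = batch_size // len(gpus)
    tower_grads, tower_losses, tower_accs = [], [], []
    for i, gpu in enumerate(gpus):
        with tf.device("/gpu:%d" % gpu):
            # Each tower takes its own slice of the batch.
            x_i = input_x[i * per_gpu:(i + 1) * per_gpu]
            y_i = input_y[i * per_gpu:(i + 1) * per_gpu]
            loss, acc = BatchLossSketch(text_length, x_i, y_i, vocab_size,
                                        dropout_keep_prob, embedding_dim,
                                        filter_sizes, num_filters,
                                        l2_reg_lambda)
            # Every tower after the first reuses the first tower's variables.
            tf.get_variable_scope().reuse_variables()
            tower_grads.append(opt.compute_gradients(loss))
            tower_losses.append(loss)
            tower_accs.append(acc)
    # Average the per-tower gradients variable by variable.
    avg_grads = []
    for grads_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grads_and_vars]
        avg_grads.append((tf.reduce_mean(tf.concat(grads, 0), 0),
                          grads_and_vars[0][1]))
    train_op = opt.apply_gradients(avg_grads, global_step=global_step)
    average_loss = tf.reduce_mean(tf.stack(tower_losses))
    average_accuracy = tf.reduce_mean(tf.stack(tower_accs))
    return train_op, average_loss, average_accuracy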
def Train(train_path, dev_path,
          word_index_dict_size, text_length, embedding_dim,
          filter_sizes, num_filters, l2_reg_lambda, batch_size,
          log_device_placement, model_dir, dropout_keep_prob_value,
          num_epochs, evaluate_every, checkpoint_every, gpus, cpus):
    g1 = tf.Graph()
    # Initialized up front so the checkpoint check below cannot hit an
    # unset name if it fires before the first evaluation.
    valid_accuracy = 0
    valid_max_accuracy = 0
    with g1.as_default(), tf.device("/gpu:" + str(gpus[0])):
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=log_device_placement)
        config.gpu_options.allow_growth = True
        config.gpu_options.allocator_type = 'BFC'
        sess = tf.Session(config=config)

        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0), trainable=False)
        opt = tf.train.AdamOptimizer(1e-3)
        #opt = tf.train.GradientDescentOptimizer(1e-3)

        input_x = tf.placeholder(tf.int32, [None, text_length], name="input_x")
        input_y = tf.placeholder(tf.float32, [None, 9], name="input_y")
        # For online predicting.
        dev_x = tf.placeholder(tf.int32, [None, text_length], name="dev_x")
        dev_y = tf.placeholder(tf.float32, [None, 9], name="dev_y")
        dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")

        train_op, average_loss, average_accuracy = MultiGPUCalcLossAccu(
            text_length, gpus, input_x, input_y, batch_size,
            word_index_dict_size, dropout_keep_prob, embedding_dim,
            filter_sizes, num_filters, l2_reg_lambda, opt, global_step, cpus)
        dev_loss, dev_acc = DevLoss(
            text_length, dev_x, dev_y, dev_path, word_index_dict_size,
            dropout_keep_prob, embedding_dim, filter_sizes, num_filters,
            l2_reg_lambda, cpus)

        # Saver / restore. `reload_model` is assumed to be a module-level
        # flag defined elsewhere in this file.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if reload_model and ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Restored model at global step {}".format(
                sess.run(global_step)))
        else:
            sess.run(tf.global_variables_initializer())
        tf.train.start_queue_runners(sess=sess)

        # Output directory for models and summaries.
        out_dir = os.path.abspath(os.path.join(os.path.curdir, model_dir))
        print("Writing to {}\n".format(out_dir))

        # Summaries for loss and accuracy. TensorFlow uniquifies the
        # duplicate "loss"/"accuracy" tags for the dev summaries.
        loss_summary = tf.summary.scalar("loss", average_loss)
        acc_summary = tf.summary.scalar("accuracy", average_accuracy)
        dev_loss_summary = tf.summary.scalar("loss", dev_loss)
        dev_acc_summary = tf.summary.scalar("accuracy", dev_acc)

        # Train summaries.
        train_summary_op = tf.summary.merge([loss_summary, acc_summary])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir,
                                                     sess.graph)

        # Dev summaries.
        dev_summary_op = tf.summary.merge([dev_loss_summary, dev_acc_summary])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        for epoch in range(num_epochs):
            train_iter = data.BatchIter(train_path, batch_size)
            for train_batch in train_iter:
                x_batch, y_batch = zip(*train_batch)
                train_feed_dict = {
                    input_x: x_batch,
                    input_y: y_batch,
                    dropout_keep_prob: dropout_keep_prob_value
                }
                TrainStep(sess, train_feed_dict, train_op, global_step,
                          train_summary_op, average_loss, average_accuracy,
                          train_summary_writer)
                current_step = tf.train.global_step(sess, global_step)
                if current_step % evaluate_every == 0:
                    print("\nEvaluation:")
                    x, y = data.LoadDevData(dev_path)
                    dev_feed_dict = {
                        dev_x: x,
                        dev_y: y,
                        dropout_keep_prob: 1.0  # No dropout at eval time.
                    }
                    valid_accuracy = DevStep(sess, dev_feed_dict,
                                             dev_summary_op, dev_loss,
                                             dev_acc, current_step,
                                             dev_summary_writer)
                    print("")
                if current_step % checkpoint_every == 0:
                    # Only checkpoint when dev accuracy improves.
                    if valid_accuracy > valid_max_accuracy:
                        valid_max_accuracy = valid_accuracy
                        checkpoint_path = os.path.join(model_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path,
                                   global_step=current_step)
                        # Optional cleanup/export kept from earlier revisions:
                        #RemovePBXStepsBefore(model_dir, 'cnn', 4,
                        #                     current_step, checkpoint_every)
                        #SavePB(sess, g1, os.path.join(model_dir, 'models'),
                        #       'cnn', current_step=str(current_step))
                        #SavePB(sess, g1, os.path.join(model_dir, 'models'),
                        #       'cnn')
        sess.close()
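# TrainStep and DevStep are likewise defined elsewhere. Minimal sketches of
# the call contract used above -- run the ops, write summaries, and (for
# DevStep) return the dev accuracy -- follow; they are illustrations under
# those assumptions, not the project's actual implementations.
def TrainStepSketch(sess, feed_dict, train_op, global_step, summary_op,
                    average_loss, average_accuracy, summary_writer):
    # One optimizer step; log loss/accuracy and emit a train summary.
    _, step, summaries, loss_v, acc_v = sess.run(
        [train_op, global_step, summary_op, average_loss, average_accuracy],
        feed_dict=feed_dict)
    print("step {}, loss {:g}, acc {:g}".format(step, loss_v, acc_v))
    summary_writer.add_summary(summaries, step)

def DevStepSketch(sess, feed_dict, summary_op, dev_loss, dev_acc,
                  current_step, summary_writer):
    # Evaluate the shared-weight dev tower and emit a dev summary.
    summaries, loss_v, acc_v = sess.run(
        [summary_op, dev_loss, dev_acc], feed_dict=feed_dict)
    print("step {}, dev loss {:g}, dev acc {:g}".format(
        current_step, loss_v, acc_v))
    summary_writer.add_summary(summaries, current_step)
    return acc_v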
# Parameters
#tf.flags.DEFINE_string("test_data", "../gen_class_data/test_data/0617tag-3-20.test", "")
#tf.flags.DEFINE_string("test_data", "../gen_class_data/train-0614-20/train-0613-test.seg.id", "")
tf.flags.DEFINE_string(
    "test_data",
    "/mnt/workspace/yezhenxu/data/termlevel/len40/test/topic.s2id", "")
#tf.flags.DEFINE_string("model_dir", "./model_test0617-20/models/cnn.pb10000", "")
tf.flags.DEFINE_string(
    "model_dir",
    "/mnt/workspace/yezhenxu/model/term_model/len40/models/cnn.pb10000", "")
FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()

x, y = data.LoadDevData(FLAGS.test_data)
print("Loaded {} test examples from {}".format(len(y), FLAGS.test_data))

# Per-category statistics over the 9 classes.
num_classes = 9
cat_num_ary = [0] * num_classes        # predicted count per category
cat_right_num_ary = [0] * num_classes  # correct predictions per category
cat_max_ary = [0] * num_classes        # highest score seen per category
cat_min_ary = [1.1] * num_classes      # lowest score seen (scores <= 1.0)
cat_total_ary = [0] * num_classes      # gold count per category
cat_recall_ary = [0] * num_classes
cat_num_total = 0
for i in range(len(y)):
    index = get_index(y[i])
    if index < 0:
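# The loop above (truncated here) relies on a get_index helper, assumed to
# map a one-hot (or all-zero) label row to its category index, returning -1
# when no category is set so that unlabeled rows can be skipped. A sketch
# under that assumption:
def get_index_sketch(label_row):
    for j, v in enumerate(label_row):
        if v > 0:
            return j
    return -1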