def dev_point_wise():
    """Train the point-wise CNN and checkpoint the best dev-accuracy epoch.

    Reads its configuration from the module-level ``FLAGS``.  For the
    ``'TREC'`` and ``'sst2'`` data sets a train/dev/test triple is loaded
    with ``load_trec_sst2``; any other data set is loaded as a train/dev
    pair with ``load``.  The vocabulary and embeddings are built by
    ``prepare`` over every loaded split.

    For each of ``FLAGS.num_epochs`` epochs the function runs one pass over
    the training batches, then measures accuracy on the full train and dev
    frames (compared against their ``'flag'`` column).  Whenever the dev
    accuracy exceeds the best seen so far, the session is saved to
    ``model_save/model``.

    Side effects:
        * writes the flag dump and one line per epoch to the file named by
          the module-level ``precision``;
        * appends the best dev accuracy to the module-level ``acc_flod``;
        * prints progress to stdout.

    Returns:
        None.
    """
    has_test_split = FLAGS.data in ('TREC', 'sst2')
    if has_test_split:
        train, dev, test = load_trec_sst2(FLAGS.data)
        corpora = [train, dev, test]
    else:
        train, dev = load(FLAGS.data)
        corpora = [train, dev]

    # Longest training question, measured in whitespace-separated tokens.
    q_max_sent_length = max(map(len, train['question'].str.split()))
    print(q_max_sent_length)
    print(len(train))
    print('train question unique:{}'.format(len(train['question'].unique())))
    print('train length', len(train))
    print('dev length', len(dev))

    alphabet, embeddings = prepare(corpora,
                                   max_sent_length=q_max_sent_length,
                                   dim=FLAGS.embedding_dim,
                                   is_embedding_needed=True,
                                   fresh=True)
    print('alphabet:', len(alphabet))

    with tf.Graph().as_default():
        # NOTE(review): the original indentation was lost; the device scope
        # is assumed to wrap only the session configuration -- confirm
        # against the upstream file.
        with tf.device("/gpu:0"):
            session_conf = tf.ConfigProto()
            session_conf.allow_soft_placement = FLAGS.allow_soft_placement
            session_conf.log_device_placement = FLAGS.log_device_placement
            session_conf.gpu_options.allow_growth = True

        print(time.strftime("%Y%m%d%H%M%S", time.localtime()))

        # Using the session itself as a context manager installs it as the
        # default session AND closes it on exit, releasing its resources.
        with tf.Session(config=session_conf) as sess, \
                open(precision, "w") as log:
            log.write(str(FLAGS.__flags) + '\n')

            cnn = CNN(max_input_left=q_max_sent_length,
                      vocab_size=len(alphabet),
                      embeddings=embeddings,
                      embedding_size=FLAGS.embedding_dim,
                      batch_size=FLAGS.batch_size,
                      filter_sizes=list(
                          map(int, FLAGS.filter_sizes.split(","))),
                      num_filters=FLAGS.num_filters,
                      l2_reg_lambda=FLAGS.l2_reg_lambda,
                      is_Embedding_Needed=True,
                      trainable=FLAGS.trainable,
                      dataset=FLAGS.data,
                      extend_feature_dim=FLAGS.extend_feature_dim)
            cnn.build_graph()

            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)

            # Build the saver once, after every variable exists.  Creating a
            # new Saver per checkpoint would add a fresh set of save/restore
            # ops to the graph each time.
            saver = tf.train.Saver()

            sess.run(tf.global_variables_initializer())

            acc_max = 0.0000
            for i in range(FLAGS.num_epochs):
                datas = batch_gen_with_point_wise(train, alphabet,
                                                  FLAGS.batch_size,
                                                  q_len=q_max_sent_length)
                for data in datas:
                    feed_dict = {
                        cnn.question: data[0],
                        cnn.input_y: data[1],
                        cnn.q_position: data[2],
                        cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                    }
                    _, step, loss, accuracy = sess.run(
                        [train_op, global_step, cnn.loss, cnn.accuracy],
                        feed_dict)
                    time_str = datetime.datetime.now().isoformat()
                    print("{}: step {}, loss {:g}, acc {:g} ".format(
                        time_str, step, loss, accuracy))

                # End-of-epoch evaluation on the full train and dev frames.
                predicted = predict(sess, cnn, train, alphabet,
                                    FLAGS.batch_size, q_max_sent_length)
                train_labels = np.argmax(predicted, 1)
                acc_train = accuracy_score(train_labels, train['flag'])

                predicted_dev = predict(sess, cnn, dev, alphabet,
                                        FLAGS.batch_size, q_max_sent_length)
                dev_labels = np.argmax(predicted_dev, 1)
                acc_dev = accuracy_score(dev_labels, dev['flag'])

                # Keep only the checkpoint with the best dev accuracy so far.
                if acc_dev > acc_max:
                    saver.save(sess, "model_save/model",
                               write_meta_graph=True)
                    acc_max = acc_dev

                print("{}:train epoch:acc {}".format(i, acc_train))
                print("{}:dev epoch:acc {}".format(i, acc_dev))
                line2 = " {}:epoch: acc{}".format(i, acc_dev)
                log.write(line2 + '\n')
                log.flush()

            acc_flod.append(acc_max)
            # No explicit log.close(): the ``with`` block closes the file.
line2 = " {}:epoch: acc{}".format(i, acc_dev) log.write(line2 + '\n') log.flush() acc_flod.append(acc_max) log.close() if __name__ == '__main__': if FLAGS.data == 'TREC' or FLAGS.data == 'sst2': for attr, value in sorted(FLAGS.__flags.items()): print(("{}={}".format(attr.upper(), value))) dev_point_wise() ckpt = tf.train.get_checkpoint_state("model_save" + '/') saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta') train, dev, test = load_trec_sst2(FLAGS.data) q_max_sent_length = max( map(lambda x: len(x), train['question'].str.split())) alphabet, embeddings = prepare([train, test, dev], max_sent_length=q_max_sent_length, dim=FLAGS.embedding_dim, is_embedding_needed=True, fresh=True) with tf.Session() as sess: saver.restore(sess, ckpt.model_checkpoint_path) graph = tf.get_default_graph() scores = [] question = graph.get_operation_by_name('input_question').outputs[0] q_position = graph.get_operation_by_name('q_position').outputs[0] dropout_keep_prob = graph.get_operation_by_name( 'dropout_keep_prob').outputs[0]