def train(estimator, train_file_pattern="./data/part-00000"):
    """Train *estimator* until the input pipeline is exhausted.

    Args:
        estimator: a ``tf.estimator.Estimator`` instance.
        train_file_pattern: path/pattern of the training data. Generalized
            from a hard-coded constant; the default preserves old behavior.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # `params` is the module-level configuration dict shared by the pipeline.
    data_iterator = DataIterator(params)
    train_input_fn = lambda: data_iterator.input_fn(train_file_pattern, 'offline')
    # steps=None: run until the input_fn raises OutOfRangeError (all epochs read).
    estimator.train(input_fn=train_input_fn, steps=None)
def predict(data_params):
    """Restore a checkpointed DSSM model and score every batch in the data file.

    Bug fixes vs. the previous version:
      * ``tf.global_variables_initializer()`` used to run AFTER
        ``saver.restore``, overwriting the restored weights with fresh random
        values. Initialization now happens first, restore last.
      * ``iterator.get_next()`` used to be called inside the while loop,
        adding new ops to the graph on every iteration (unbounded graph
        growth). The op is now created once, before the loop.

    Args:
        data_params: configuration dict passed to ``DataIterator``.
    """
    meta_path = "./model/dssm.ckpt.meta"
    ckpt_path = "./model/dssm.ckpt"
    data_file = "./data/train.txt.10"
    dssm = DSSM()
    data_iterator = DataIterator(data_params)
    iterator = data_iterator.input_fn(data_file)
    # Create the dequeue op ONCE; sess.run() on it yields successive batches.
    next_batch = iterator.get_next()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(meta_path)
        sess.run(tf.global_variables_initializer())
        # Restore AFTER init so the checkpoint values win.
        saver.restore(sess, ckpt_path)
        sess.run(iterator.initializer)
        s = time.time()
        while True:
            try:
                (batch_query, batch_creative_ids, batch_labels) = sess.run(next_batch)
                prediction = sess.run(
                    dssm.score,
                    feed_dict={dssm.query: batch_query,
                               dssm.doc: batch_creative_ids})
                print(prediction)
            except tf.errors.OutOfRangeError:
                break
        e = time.time()
        # ~0.0001 s per example on average.
        print(e - s)
def train_and_eval(estimator,
                   train_file_pattern="./data/part-00000",
                   eval_file_pattern="./data/part-5"):
    """Run interleaved training and evaluation via ``train_and_evaluate``.

    The file patterns are now real parameters (the old code noted they
    "could be the parameters of FLAGS"); defaults preserve old behavior.

    Args:
        estimator: a ``tf.estimator.Estimator`` instance.
        train_file_pattern: path/pattern of the training data.
        eval_file_pattern: path/pattern of the evaluation data.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    data_iterator = DataIterator(params)
    train_input_fn = lambda: data_iterator.input_fn(train_file_pattern, 'offline')
    eval_input_fn = lambda: data_iterator.input_fn(eval_file_pattern, 'offline')
    # max_steps=None: train until the input pipeline is exhausted.
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=None)
    # Evaluate 100 batches, starting 60 s after launch, at most every 30 s.
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=100,
                                      start_delay_secs=60, throttle_secs=30)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
def eval(estimator, eval_file_pattern="./data/part-5"):
    """Evaluate *estimator* and print its AUC score.

    NOTE(review): the name shadows the builtin ``eval``; renaming would
    break existing callers, so it is kept as-is.

    Args:
        estimator: a ``tf.estimator.Estimator`` instance.
        eval_file_pattern: path/pattern of the evaluation data. Generalized
            from a hard-coded constant; the default preserves old behavior.
    """
    data_iterator = DataIterator(params)
    eval_input_fn = lambda: data_iterator.input_fn(eval_file_pattern, 'offline')
    eval_results = estimator.evaluate(input_fn=eval_input_fn)
    # eval_results["auc"] is a numpy.float32; %-formatting handles it directly.
    auc_score = eval_results["auc"]
    print("\nTest auc: %.6f" % auc_score)
def main():
    """Build the estimator and run the combined train/evaluate loop.

    Cleanup vs. the previous version: removed three unused input_fn lambdas
    and a duplicated commented-out estimator line — ``train_and_eval``
    constructs its own input functions from the file patterns.
    """
    estimator = tf.estimator.Estimator(model_fn=model_fn, params=params,
                                       model_dir="./model")
    # Alternative single-phase entry points: train(estimator) / eval(estimator).
    train_and_eval(estimator)
def predict(estimator):
    """Score each line of the prediction file, one instance at a time.

    NOTE(review): this redefines ``predict`` from earlier in the module
    (with a different signature); only the last definition is callable.

    Each input line is ``dmp_id<TAB>feature1<TAB>...``; the features are
    re-joined and fed to the online input pipeline.

    Args:
        estimator: a ``tf.estimator.Estimator`` instance.
    """
    predict_file = "./data/part-predict"
    data_iterator = DataIterator(params)
    with open(predict_file, 'r') as infile:
        for line in infile:
            items = line.strip('\n').split('\t')
            dmp_id = items[0]
            ins = "\t".join(items[1:])
            # Bind `ins` as a default argument: estimator.predict() invokes
            # the input_fn lazily, and a late-binding closure over the loop
            # variable is a classic hazard.
            predict_input_fn = lambda ins=ins: data_iterator.input_fn(ins, 'online')
            # Only the first prediction per instance is needed.
            predictions = itertools.islice(
                estimator.predict(input_fn=predict_input_fn), 1)
            for p in predictions:
                print("dmp_id %s: logits:%.6f probability:%.6f"
                      % (dmp_id, p["logits"], p["probabilities"]))
import tensorflow as tf
from data_iterator import DataIterator

# Pipeline configuration shared with DataIterator.
params = {
    "shuffle_buffer_size": 1000,
    "num_parallel_calls": 4,
    "epoch": 10,
    "batch_size": 4
}

data_iterator = DataIterator(params)
data_file = "./data/train.txt.10"

iterator = data_iterator.input_fn(data_file)
# Fix vs. the previous version: build the get_next() op ONCE outside the
# loop (calling it per-iteration added new ops to the graph every pass),
# and manage the session with `with` so it is closed on exit.
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    # Drain the dataset, printing one (query_features, creative_ids, labels)
    # batch per step, until the iterator is exhausted.
    while True:
        try:
            print(sess.run(next_batch))
        except tf.errors.OutOfRangeError:
            break