Beispiel #1
0
def train(estimator):
    """Train `estimator` on the fixed offline training shard until the
    input pipeline is exhausted (steps=None)."""
    tf.logging.set_verbosity(tf.logging.INFO)
    pattern = "./data/part-00000"
    it = DataIterator(params)

    def train_input_fn():
        # Offline mode: read records matching the file pattern.
        return it.input_fn(pattern, 'offline')

    estimator.train(input_fn=train_input_fn, steps=None)
Beispiel #2
0
def predict(data_params):
    """Score every batch of `data_file` with a DSSM restored from checkpoint.

    Fixes two defects in the original:
      * `iterator.get_next()` was called inside the read loop, which adds a
        new op to the TF1 graph on every iteration (unbounded graph growth).
        The dequeue op is now built once, before the loop.
      * `tf.global_variables_initializer()` was run AFTER `saver.restore`,
        overwriting the restored checkpoint weights with fresh random
        initial values. Initialization now happens before the restore, so
        the checkpoint values win.
    """
    meta_path = "./model/dssm.ckpt.meta"
    ckpt_path = "./model/dssm.ckpt"
    data_file = "./data/train.txt.10"
    dssm = DSSM()
    data_iterator = DataIterator(data_params)
    iterator = data_iterator.input_fn(data_file)
    # Build the dequeue op exactly once; re-calling get_next() leaks graph nodes.
    next_element = iterator.get_next()
    # config
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(meta_path)
        # Initialize first, then restore — restore must not be clobbered.
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt_path)
        sess.run(iterator.initializer)
        s = time.time()
        while True:
            try:
                (batch_query, batch_creative_ids, batch_labels) = sess.run(
                    next_element)
                prediction = sess.run(dssm.score,
                                      feed_dict={
                                          dssm.query: batch_query,
                                          dssm.doc: batch_creative_ids
                                      })
                print(prediction)
            except tf.errors.OutOfRangeError:
                break
        e = time.time()
        # roughly 0.0001s per record on average
        print(e - s)
Beispiel #3
0
def train_and_eval(estimator):
    """Run interleaved training and evaluation via
    tf.estimator.train_and_evaluate."""
    tf.logging.set_verbosity(tf.logging.INFO)
    # train_file_pattern and eval_file_pattern could be the parameters of FLAGS
    train_pattern = "./data/part-00000"
    eval_pattern = "./data/part-5"
    it = DataIterator(params)

    def train_input_fn():
        return it.input_fn(train_pattern, 'offline')

    def eval_input_fn():
        return it.input_fn(eval_pattern, 'offline')

    spec_train = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=None)
    # Evaluate 100 steps at a time, first after 60s, then at most every 30s.
    spec_eval = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                      steps=100,
                                      start_delay_secs=60,
                                      throttle_secs=30)
    tf.estimator.train_and_evaluate(estimator, spec_train, spec_eval)
Beispiel #4
0
def eval(estimator):
    """Evaluate `estimator` on the eval shard and print its AUC.

    NOTE(review): the name shadows the builtin `eval`; kept unchanged
    because callers reference this function by name.
    """
    pattern = "./data/part-5"
    it = DataIterator(params)

    def eval_input_fn():
        return it.input_fn(pattern, 'offline')

    results = estimator.evaluate(input_fn=eval_input_fn)
    auc_score = results["auc"]  # numpy.float32
    print("\nTest auc: %.6f" % auc_score)
Beispiel #5
0
def main():
    """Build the estimator and launch combined training + evaluation.

    The original built `train_input_fn`, `eval_input_fn`, `predict_input_fn`
    and a local `DataIterator` that were never used — train_and_eval() (and
    the train/eval alternatives) construct their own input functions.
    Those dead locals are removed.
    """
    # define estimator
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       params=params,
                                       model_dir="./model")

    # Alternatives: train(estimator), eval(estimator), predict(estimator)
    train_and_eval(estimator)
Beispiel #6
0
def predict(estimator):
    """Predict one score per line of `predict_file` in online mode.

    Each input line is "<dmp_id>\\t<feature fields...>"; the feature part is
    passed to the online input_fn and only the first prediction per line is
    printed.

    Fixes vs. the original: the unused `enumerate` index is dropped, and
    `ins` is bound as a default argument so the lambda is immune to the
    late-binding-closure pitfall should it ever outlive the loop iteration.
    """
    predict_file = "./data/part-predict"
    data_iterator = DataIterator(params)
    with open(predict_file, 'r') as infile:
        for line in infile:
            items = line.strip('\n').split('\t')
            dmp_id = items[0]
            ins = "\t".join(items[1:])
            # Default-arg binding captures the CURRENT value of ins.
            predict_input_fn = lambda ins=ins: data_iterator.input_fn(
                ins, 'online')
            predictions = estimator.predict(input_fn=predict_input_fn)
            # Only the first prediction for this line is wanted.
            for p in itertools.islice(predictions, 1):
                print("dmp_id %s: logits:%.6f probability:%.6f" %
                      (dmp_id, p["logits"], p["probabilities"]))
Beispiel #7
0
import tensorflow as tf
from data_iterator import DataIterator
# import tensorflow.contrib.eager as tfe

# tf.enable_eager_execution()

# Pipeline hyper-parameters for the DataIterator.
params = {
    "shuffle_buffer_size": 1000,
    "num_parallel_calls": 4,
    "epoch": 10,
    "batch_size": 4
}
data_iterator = DataIterator(params)

data_file = "./data/train.txt.10"

# (features, creative_ids, labels) = data_iterator.input_fn(data_file)
iterator = data_iterator.input_fn(data_file)
# Build the dequeue op ONCE. The original called iterator.get_next() inside
# the loop, adding a new node to the TF1 graph on every iteration
# (unbounded graph growth and slowdown over long runs).
(query_features, creative_ids, labels) = iterator.get_next()

# Use a context manager so the session is always closed (the original
# leaked an open tf.Session).
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    while True:
        try:
            print(sess.run([query_features, creative_ids, labels]))
        except tf.errors.OutOfRangeError:
            break