def main(_):
    """Entry point for the DIN model: build hyper-parameters from FLAGS,
    construct the Estimator, and dispatch on FLAGS.task_type
    (train / eval / infer / export).

    Args:
        _: unused positional argument supplied by tf.app.run().
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    set_dist_env()

    # ------ build task hyper-parameters from command-line FLAGS ------
    # Comma-separated flag strings are parsed into int/float lists.
    model_params = {
        "learning_rate": FLAGS.learning_rate,
        "l2_reg": FLAGS.l2_reg,
        "deep_layers": list(map(int, FLAGS.deep_layers.split(','))),
        "atten_layers": list(map(int, FLAGS.atten_layers.split(','))),
        "dropout": list(map(float, FLAGS.dropout.split(','))),
        "optimizer": FLAGS.optimizer,
    }

    if FLAGS.clear_existing_model:
        try:
            shutil.rmtree('./model')
        except Exception as e:
            # Best-effort cleanup: a missing directory is not fatal.
            print(e, "at clear_existing_model")
        else:
            # FIX: the message used to report FLAGS.model_dir, but the
            # directory actually removed is the hard-coded './model'.
            print("existing model cleaned at %s" % './model')

    tr_files = "./data/train.tfrecords"
    va_files = "./data/test.tfrecords"

    # Build the feature pipeline and wrap the DIN network in the
    # project's generic Model (provides model_fn / input_fn).
    fea_json = feature_json('./feature_generator.json')
    fg = FeatureGenerator(fea_json)
    md = DIN(fg)
    model = Model(fg, md)

    # CPU-only run config; GPU count forced to 0.
    config = tf.estimator.RunConfig().replace(
        session_config=tf.ConfigProto(
            device_count={'GPU': 0, 'CPU': FLAGS.num_threads}),
        log_step_count_steps=FLAGS.log_steps,
        save_summary_steps=FLAGS.log_steps)
    Estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir='./model/',
        params=model_params,
        config=config)

    if FLAGS.task_type == 'train':
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: model.input_fn(
                tr_files,
                num_epochs=FLAGS.num_epochs,
                batch_size=FLAGS.batch_size))
        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda: model.input_fn(
                va_files, num_epochs=1, batch_size=FLAGS.batch_size),
            steps=None,
            start_delay_secs=1000,
            throttle_secs=1200)
        tf.estimator.train_and_evaluate(Estimator, train_spec, eval_spec)
    elif FLAGS.task_type == 'eval':
        # Evaluate on both the training and the held-out set.
        Estimator.evaluate(
            input_fn=lambda: model.input_fn(
                tr_files, num_epochs=1, batch_size=FLAGS.batch_size))
        Estimator.evaluate(
            input_fn=lambda: model.input_fn(
                va_files, num_epochs=1, batch_size=FLAGS.batch_size))
    elif FLAGS.task_type == 'infer':
        preds = Estimator.predict(
            input_fn=lambda: model.input_fn(
                va_files, num_epochs=1, batch_size=FLAGS.batch_size),
            predict_keys="prob")
        # FIX: predict() returns a lazy generator; it must be consumed
        # or no inference actually runs.
        for pred in preds:
            print(pred["prob"])
    elif FLAGS.task_type == 'export':
        # Standalone export for serving. A parsing receiver
        # (build_parsing_serving_input_receiver_fn(fg.feature_spec))
        # could be used instead when serving tf.Example protos.
        serving_input_receiver_fn = (
            tf.estimator.export.build_raw_serving_input_receiver_fn(
                fg.feature_placeholders))
        Estimator.export_saved_model(
            FLAGS.servable_model_dir, serving_input_receiver_fn)
def main(_):
    """Entry point for the DIEN model: build hyper-parameters from FLAGS,
    construct the Estimator, and dispatch on FLAGS.task_type
    (train / eval / infer / export).

    Args:
        _: unused positional argument supplied by tf.app.run().
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    set_dist_env()

    # ------ build task hyper-parameters from command-line FLAGS ------
    # Comma-separated flag strings are parsed into int/float lists.
    model_params = {
        "learning_rate": FLAGS.learning_rate,
        "l2_reg": FLAGS.l2_reg,
        "fcn_layers": list(map(int, FLAGS.fcn_layers.split(','))),
        "atten_layers": list(map(int, FLAGS.atten_layers.split(','))),
        "auxili_layers": list(map(int, FLAGS.auxili_layers.split(','))),
        "dropout": list(map(float, FLAGS.dropout.split(','))),
        "optimizer": FLAGS.optimizer,
        "neg_count": FLAGS.neg_count,
    }

    # Negative-sampling source (item -> category mapping).
    # FIX: the file handle was previously leaked via cPickle.load(open(...)).
    with open("data/mid_cat.pkl", "rb") as f:
        model_params["mid_cat"] = cPickle.load(f)

    tr_files = "./data/train.tfrecords"
    va_files = "./data/test.tfrecords"

    # Build the feature pipeline and wrap the DIEN network in the
    # project's generic Model (provides model_fn / input_fn).
    fea_json = feature_json('./feature_generator.json')
    fg = FeatureGenerator(fea_json)
    md = DIEN(fg)
    model = Model(fg, md)

    # CPU-only run config; GPU count forced to 0.
    config = tf.estimator.RunConfig().replace(
        session_config=tf.ConfigProto(
            device_count={'GPU': 0, 'CPU': FLAGS.num_threads}),
        log_step_count_steps=FLAGS.log_steps,
        save_summary_steps=FLAGS.log_steps,
        save_checkpoints_secs=FLAGS.save_checkpoints_secs)
    Estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir='./model/',
        params=model_params,
        config=config)

    if FLAGS.task_type == 'train':
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: model.input_fn(
                tr_files,
                num_epochs=FLAGS.num_epochs,
                batch_size=FLAGS.batch_size))
        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda: model.input_fn(
                va_files, num_epochs=1, batch_size=FLAGS.batch_size),
            steps=None,
            start_delay_secs=10,
            # Align evaluation cadence with checkpoint cadence.
            throttle_secs=FLAGS.save_checkpoints_secs)
        tf.estimator.train_and_evaluate(Estimator, train_spec, eval_spec)
    elif FLAGS.task_type == 'eval':
        # Evaluate on both the training and the held-out set.
        Estimator.evaluate(
            input_fn=lambda: model.input_fn(
                tr_files, num_epochs=1, batch_size=FLAGS.batch_size))
        Estimator.evaluate(
            input_fn=lambda: model.input_fn(
                va_files, num_epochs=1, batch_size=FLAGS.batch_size))
    elif FLAGS.task_type == 'infer':
        preds = Estimator.predict(
            input_fn=lambda: model.input_fn(
                va_files, num_epochs=1, batch_size=FLAGS.batch_size),
            predict_keys="prob")
        # FIX: predict() returns a lazy generator; it must be consumed
        # or no inference actually runs.
        for pred in preds:
            print(pred["prob"])
    elif FLAGS.task_type == 'export':
        # Standalone export for serving. A parsing receiver
        # (build_parsing_serving_input_receiver_fn(fg.feature_spec))
        # could be used instead when serving tf.Example protos.
        serving_input_receiver_fn = (
            tf.estimator.export.build_raw_serving_input_receiver_fn(
                fg.feature_placeholders))
        Estimator.export_saved_model(
            FLAGS.servable_model_dir, serving_input_receiver_fn)