def main(_):
    """Build the estimator named by ``params.alg_name`` and train or predict.

    The action is selected by ``params.action_type``:

    * ``"train"``: fit the model on the files listed under ``params.train_path``
      (the ``params.predict_path`` files are also handed to ``op.model_fit``).
      Then write the exported model path to ``<params.model_pb>/test`` and
      run prediction with that freshly trained model.
    * ``"pred"``: run prediction with a previously exported model located at a
      hard-coded directory.

    Any other ``action_type`` only prints an error message.
    An unknown ``alg_name`` prints an error and terminates the process with
    exit status -1.

    Args:
        _: Unused positional argument (presumably the argv list passed by an
            app runner such as ``tf.app.run`` -- TODO confirm).
    """
    import sys

    params = ModelParams()
    for key, value in params.__dict__.items():
        print(key, "=", value)

    # Dispatch table of lazy builders: a module name is only looked up when its
    # entry is actually called, mirroring the short-circuiting of an elif chain.
    model_builders = {
        "dnn": lambda p: dnn.model_estimator(p),
        "dnn_pool": lambda p: dnn_pool.model_estimator(p),
        "dnn_cate": lambda p: dnn_cate.model_estimator(p),
        "deepfm": lambda p: deepfm.model_estimator(p),
        "deepfm_pool": lambda p: deepfm_pool.model_estimator(p),
        "deepfm_cate": lambda p: deepfm_cate.model_estimator(p),
        "din": lambda p: din.model_estimator(p),
        "dinfm": lambda p: dinfm.model_estimator(p),
        "dien": lambda p: dien.model_estimator(p),
        "dnn_emb": lambda p: dnn_emb.model_estimator(p),
        "dnn_autoint": lambda p: dnn_autoint.model_estimator(p),
    }
    builder = model_builders.get(params.alg_name)
    if builder is None:
        # Unknown algorithm: report and abort without building any model.
        print("alg_name = %s is error" % params.alg_name)
        sys.exit(-1)
    model = builder(params)

    if params.action_type == "train":
        start_time = time.time()
        train_files = data_load.get_file_list(params.train_path)
        predict_files = data_load.get_file_list(params.predict_path)
        print("--------------train------------")
        trained_model_path = op.model_fit(model, params, train_files, predict_files)
        end_time = time.time()
        print("model_save training time: %.2f s" % (end_time - start_time))

        # Persist the exported model path (presumably for later lookup). The
        # exporter may return bytes (Estimator export APIs do) or str; accept both.
        if isinstance(trained_model_path, bytes):
            path_text = trained_model_path.decode("utf-8")
        else:
            path_text = str(trained_model_path)
        # The context manager guarantees the file is flushed and closed.
        with tf.gfile.GFile(params.model_pb + "/test", "w") as f:
            f.write(path_text)

        print("--------------predict------------")
        op.model_predict(trained_model_path, predict_files, params)
    elif params.action_type == "pred":
        print("--------------predict------------")
        predict_files = data_load.get_file_list(params.predict_path)
        # NOTE(review): machine-specific hard-coded model directory; consider
        # taking it from params (e.g. params.model_pb) instead.
        op.model_predict(
            "/Users/R.Stalker/PycharmProjects/deep_learing_estimator/files/model_save_pb/deepfm",
            predict_files,
            params,
        )
    else:
        print("action_type = %s is error !!!" % params.action_type)
def main(_):
    """Clean stale data, build the estimator named by ``params.alg_name``, then train or evaluate.

    Steps:

    1. Print every attribute of ``ModelParams``.
    2. Best-effort delete the train/predict data directories for the date
       ``my_utils.shift_date_time(params.dt, -1)``.
    3. Build the model.
    4. Act on ``params.mode``:

       * ``"train"``: fit on ``params.train_path`` files and write the exported
         model path to ``params.model_pb[:-9] + "latest_model_path"``. Then
         predict on ``params.predict_path`` files with the fresh model.
       * ``"eval"``: predict with the model at ``params.model_pb``.

       Any other mode only prints an error message.

    An unknown ``alg_name`` prints an error and terminates the process with
    exit status -1.

    Args:
        _: Unused positional argument (presumably the argv list passed by an
            app runner such as ``tf.app.run`` -- TODO confirm).
    """
    import sys

    params = ModelParams()
    for key, value in params.__dict__.items():
        print(key, "=", value)

    print("---delete old data...")
    # Date shifted by -1 (presumably one day before params.dt -- confirm in my_utils).
    delete_dt = my_utils.shift_date_time(params.dt, -1)
    print("---delete_dt:", delete_dt)
    # NOTE(review): [:-9] assumes the configured paths end in a fixed 9-char
    # date component that is swapped for delete_dt -- confirm the path format.
    old_train_dir = params.train_path[:-9] + delete_dt
    old_predict_dir = params.predict_path[:-9] + delete_dt
    print(old_train_dir)
    print(old_predict_dir)
    # Deliberately best-effort: missing directories are not an error.
    shutil.rmtree(old_train_dir, ignore_errors=True)
    shutil.rmtree(old_predict_dir, ignore_errors=True)

    # Dispatch table of lazy builders: a module name is only looked up when its
    # entry is actually called, mirroring the short-circuiting of an elif chain.
    model_builders = {
        "dnn": lambda p: dnn.model_estimator(p),
        "deepfm": lambda p: deepfm.model_estimator(p),
        "deepfm_pool": lambda p: deepfm_pool.model_estimator(p),
        "din": lambda p: din.model_estimator(p),
        "dnn_pool": lambda p: dnn_pool.model_estimator(p),
        "dinfm": lambda p: dinfm.model_estimator(p),
        "dien": lambda p: dien.model_estimator(p),
        "dnn_autoint": lambda p: dnn_autoint.model_estimator(p),
    }
    builder = model_builders.get(params.alg_name)
    if builder is None:
        # Unknown algorithm: report and abort without building any model.
        print("alg_name = %s is error" % params.alg_name)
        sys.exit(-1)
    model = builder(params)

    if params.mode == "train":
        start_time = time.time()
        train_files = data_load.get_file_list(params.train_path)
        predict_files = data_load.get_file_list(params.predict_path)
        print("--------------train------------")
        trained_model_path = op.model_fit(model, params, train_files, predict_files)
        end_time = time.time()

        # Persist the exported model path. The exporter may return bytes
        # (Estimator export APIs do) or str; accept both.
        if isinstance(trained_model_path, bytes):
            path_text = trained_model_path.decode("utf-8")
        else:
            path_text = str(trained_model_path)
        # The context manager guarantees the file is flushed and closed.
        with tf.gfile.GFile(params.model_pb[:-9] + "latest_model_path", "w") as f:
            f.write(path_text)

        print("model_save training time: %.2f s" % (end_time - start_time))
        print("--------------predict------------")
        op.model_predict(trained_model_path, predict_files, params)
    elif params.mode == "eval":
        print("--------------predict------------")
        predict_files = data_load.get_file_list(params.predict_path)
        op.model_predict(params.model_pb, predict_files, params)
    else:
        # The literal was previously split across two physical lines, which is
        # an unterminated string (SyntaxError); it is kept on one line here.
        print("action_type = %s is error !!!" % params.mode)
def main(_):
    """Parse/validate FLAGS, build the estimator named by ``FLAGS.alg_name``, and run ``FLAGS.task_mode``.

    Supported task modes:

    * ``"train"``: train on ``FLAGS.train_data``, then export via
      ``model_op.model_save_pb``.
    * ``"eval"``: evaluate on ``FLAGS.eval_data``.
    * ``"infer"`` and ``"debug"``: placeholders; they currently do nothing.

    Any other mode prints an error message.

    When ``FLAGS.clear_existing_model_dir`` is set, ``FLAGS.model_dir`` is
    removed first on a best-effort basis; failures are printed, not raised.
    An unknown ``alg_name`` prints an error and terminates the process with
    exit status -1.

    Args:
        _: Unused positional argument (presumably the argv list passed by an
            app runner such as ``tf.app.run`` -- TODO confirm).
    """
    import sys

    handle_arguments()
    check_arguments()

    if FLAGS.clear_existing_model_dir:
        try:
            shutil.rmtree(FLAGS.model_dir)
        except Exception as e:
            # Best-effort cleanup: report the problem (e.g. a missing directory)
            # and keep going rather than aborting the run.
            print(e, "at clear_existing_model_dir")
        else:
            print("existing model cleaned at %s" % FLAGS.model_dir)

    # Dispatch table of lazy builders: a module name is only looked up when its
    # entry is actually called, mirroring the short-circuiting of an elif chain.
    model_builders = {
        "dnn": lambda flags: dnn.model_estimator(flags),
        "deepfm": lambda flags: deepfm.model_estimator(flags),
        "din": lambda flags: din.model_estimator(flags),
        "dinfm": lambda flags: dinfm.model_estimator(flags),
        "autoint": lambda flags: autoint.model_estimator(flags),
    }
    builder = model_builders.get(FLAGS.alg_name)
    if builder is None:
        print("ERROR!!! alg_name = %s is not exit!" % FLAGS.alg_name)
        sys.exit(-1)
    model = builder(FLAGS)

    if FLAGS.task_mode == "train":
        model.train(input_fn=lambda: data_load.input_fn(FLAGS.train_data, FLAGS))
        model_op.model_save_pb(FLAGS, model)
    elif FLAGS.task_mode == "eval":
        # Uses FLAGS (a previous `FALGS` typo raised NameError in this mode).
        model.evaluate(input_fn=lambda: data_load.input_fn(FLAGS.eval_data, FLAGS))
    elif FLAGS.task_mode == "infer":
        # TODO: inference is not implemented. An earlier draft ran
        # model.predict(..., predict_keys=["item_embedding"]). It then wrote
        # "<id>|<id>|<comma-separated embedding>" lines to FLAGS.infer_result,
        # taking the id from field 10 of each space-separated FLAGS.test_data line.
        pass
    elif FLAGS.task_mode == "debug":
        # TODO: debug mode (generating fake data for debugging) is still under
        # development and currently unusable.
        pass
    else:
        print("Task_mode Error!")