trainable=False, collections=[]) assign_filename_op = filename_tensor.assign(original_assets_filename) # 定义模型导出配置 if os.path.exists(FLAGS.export_dir): print("The export path has existed, try to delete it...") shutil.rmtree(FLAGS.export_dir) print("The export path has been deleted.") model_export_spec = model_exporter.ModelExportSpec( export_dir=FLAGS.export_dir, input_tensors={'x': x}, output_tensors={'y': y}, assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS), legacy_init_op=tf.group(assign_filename_op)) return dist_base.ModelFnHandler(global_step=global_step, model_export_spec=model_export_spec) def train_fn(session, global_step): """训练模型 """ # 该线性模型训练时啥也不做。 return True if __name__ == "__main__": distTfRunner = dist_base.DistTensorflowRunner(model_fn=model_fn) distTfRunner.run(train_fn)
本函数必须接收两个参数: - scafford: tf.train.Scaffold 对象; - sess: tf.Session 对象。 """ saver.restore(sess, checkpoint_path) return init_from_checkpoint def after_train_hook(session): """模型训练操作。 TaaS 在整个模型训练结束之后会调用该函数来进行相关的善后处理。 这些善后处理需要您基于业务需要来提供,例如模型测试等。 参数: - `session`:tf.Session 对象。 """ pass if __name__ == '__main__': # 定义分布式 TensorFlow 运行器 DistTensorflowRunner 对象。 distTfRunner = dist_base.DistTensorflowRunner( model_fn=model_fn, after_train_hook=after_train_hook, gen_init_fn=gen_init_fn) # 调用 DistTensorflowRunner 对象的 run 方法执行分布式模型训练,需要传递每轮模型训练的 # 操作实现函数 train_fn。 distTfRunner.run(train_fn)
output_tensors={"infer": _infer}) # 定义模型评测(准确率)的计算方法 model_metric_ops = { "rmse": rmse_evalute_fn } return dist_base.ModelFnHandler( global_step=_global_step, optimizer=optimizer, model_metric_ops=model_metric_ops, model_export_spec=model_export_spec, summary_op=None) def train_fn(session, num_global_step): global _train_op, _infer, _user_batch, _item_batch, _rate_batch, _rmse, _local_step, _cost users, items, rates = next(_iter_train) session.run(_train_op, feed_dict={_user_batch: users, _item_batch: items, _rate_batch: rates}) if _local_step % 2000 == 0: rmse, infer, cost = session.run([_rmse, _infer, _cost], feed_dict={_user_batch: _test["user"], _item_batch: _test["item"], _rate_batch: _test["rate"]}) print("Eval RMSE at round {} is: {}".format(num_global_step, rmse)) _local_step += 1 return False if __name__ == '__main__': distTfRunner = dist_base.DistTensorflowRunner(model_fn = model_fn, gen_init_fn=None) distTfRunner.run(train_fn)