trainable=False,
                                  collections=[])
    assign_filename_op = filename_tensor.assign(original_assets_filename)

    # 定义模型导出配置
    if os.path.exists(FLAGS.export_dir):
        print("The export path has existed, try to delete it...")
        shutil.rmtree(FLAGS.export_dir)
        print("The export path has been deleted.")
    model_export_spec = model_exporter.ModelExportSpec(
        export_dir=FLAGS.export_dir,
        input_tensors={'x': x},
        output_tensors={'y': y},
        assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
        legacy_init_op=tf.group(assign_filename_op))

    return dist_base.ModelFnHandler(global_step=global_step,
                                    model_export_spec=model_export_spec)


def train_fn(session, global_step):
    """Training step for the model.

    Neither ``session`` nor ``global_step`` is used by this implementation.

    Returns:
        Always ``True``.
        NOTE(review): how the runner interprets the return value is not
        visible here -- confirm against ``dist_base``.
    """
    # This linear model does nothing at all during training.
    return True


if __name__ == "__main__":
    # Script entry point: build the distributed TensorFlow runner around
    # model_fn, then hand it train_fn via run().
    distTfRunner = dist_base.DistTensorflowRunner(model_fn=model_fn)
    distTfRunner.run(train_fn)
Exemplo n.º 2
0
        本函数必须接收两个参数:
          - scafford: tf.train.Scaffold 对象;
          - sess: tf.Session 对象。
        """
        saver.restore(sess, checkpoint_path)

    return init_from_checkpoint


def after_train_hook(session):
    """Post-training hook.

    TaaS calls this function once the whole model training has finished so
    that follow-up work can be carried out. That follow-up work is for you
    to supply according to your business needs, e.g. testing the model.

    This implementation performs no follow-up work.

    Args:
        session: The tf.Session object.
    """
    return None


if __name__ == '__main__':
    # Create the distributed TensorFlow runner (DistTensorflowRunner) object.
    distTfRunner = dist_base.DistTensorflowRunner(
        model_fn=model_fn,
        after_train_hook=after_train_hook,
        gen_init_fn=gen_init_fn)
    # Call the runner's run method to carry out distributed model training;
    # it must be given train_fn, the function that implements the work done
    # in each training round.
    distTfRunner.run(train_fn)
Exemplo n.º 3
0
        output_tensors={"infer": _infer})

    # 定义模型评测(准确率)的计算方法
    model_metric_ops = {
        "rmse": rmse_evalute_fn
    }
    
    return dist_base.ModelFnHandler(
        global_step=_global_step,
        optimizer=optimizer, 
        model_metric_ops=model_metric_ops,
        model_export_spec=model_export_spec,
        summary_op=None)
    
def train_fn(session, num_global_step):
    """Run one training step and periodically print RMSE on the test data.

    Takes the next ``(users, items, rates)`` batch from the module-level
    ``_iter_train`` iterator and runs ``_train_op`` with it. Whenever the
    module-level ``_local_step`` counter is a multiple of 2000, ``_rmse``,
    ``_infer`` and ``_cost`` are evaluated on ``_test`` and the RMSE is
    printed. ``_local_step`` is then incremented.

    Args:
        session: Object whose ``run(fetches, feed_dict=...)`` is called;
            presumably a tf.Session -- confirm against the runner.
        num_global_step: Value shown as the round number in the printed
            evaluation line.

    Returns:
        Always ``False``.
        NOTE(review): presumably tells the runner to keep training --
        confirm against ``dist_base``.

    Raises:
        StopIteration: Propagated from ``next(_iter_train)`` if that
            iterator is exhausted.
    """
    # Only the step counter is rebound here; every other module-level name
    # is merely read, so it needs no ``global`` declaration.
    global _local_step

    users, items, rates = next(_iter_train)
    train_feed = {_user_batch: users, _item_batch: items, _rate_batch: rates}
    session.run(_train_op, feed_dict=train_feed)

    if _local_step % 2000 == 0:
        eval_feed = {
            _user_batch: _test["user"],
            _item_batch: _test["item"],
            _rate_batch: _test["rate"],
        }
        # _infer and _cost are fetched together with _rmse, but only the
        # RMSE value is reported.
        eval_rmse, _, _ = session.run([_rmse, _infer, _cost],
                                      feed_dict=eval_feed)
        print("Eval RMSE at round {} is: {}".format(num_global_step, eval_rmse))

    _local_step += 1
    return False

if __name__ == '__main__':
    # Script entry point: build the distributed runner around model_fn
    # (gen_init_fn is explicitly passed as None), then hand it train_fn
    # via run().
    distTfRunner = dist_base.DistTensorflowRunner(model_fn = model_fn, gen_init_fn=None)
    distTfRunner.run(train_fn)