import sys
import time

import tensorflow as tf
# Import paths below assume TF 1.x and the flink-ai-extended (dl-on-flink)
# Python package; adjust to your installation.
from flink_ml_tensorflow.tensorflow_context import TFContext
from tensorflow.python.summary.writer.writer_cache import FileWriterCache as SummaryWriterCache


def map_func(context):
    tf_context = TFContext(context)
    job_name = tf_context.get_role_name()
    index = tf_context.get_index()
    cluster_json = tf_context.get_tf_cluster()
    print(cluster_json)
    sys.stdout.flush()

    ckpt = tf_context.get_property("checkpoint_dir")
    cluster = tf.train.ClusterSpec(cluster=cluster_json)
    server = tf.train.Server(cluster, job_name=job_name, task_index=index)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % index])

    if 'ps' == job_name:
        # Parameter servers block forever; Flink tears the task down
        # when the job finishes.
        from time import sleep
        while True:
            sleep(1)
    else:
        with tf.device(tf.train.replica_device_setter(
                worker_device='/job:worker/task:' + str(index),
                cluster=cluster)):
            # build_graph() and the placeholder `a` are defined elsewhere
            # in the example (see the sketch below).
            train_ops = build_graph()
            try:
                hooks = [tf.train.StopAtStepHook(last_step=50)]
                with tf.train.MonitoredTrainingSession(
                        master=server.target,
                        config=sess_config,
                        checkpoint_dir=ckpt,
                        hooks=hooks,
                        save_summaries_steps=1) as mon_sess:
                    while not mon_sess.should_stop():
                        print(mon_sess.run(train_ops,
                                           feed_dict={a: [1.0, 2.0, 3.0]}))
                        sys.stdout.flush()
                        time.sleep(1)
            finally:
                SummaryWriterCache.clear()
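
# map_func above references build_graph() and a global placeholder `a` that
# are defined elsewhere in the example. A minimal hypothetical sketch (TF 1.x),
# shown only to make the snippet self-contained; the real graph is
# application-specific:
def build_graph():
    global a
    # The feed vector comes from map_func's feed_dict={a: [1.0, 2.0, 3.0]}.
    a = tf.placeholder(tf.float32, shape=None, name="a")
    b = tf.reduce_mean(a, name="b")
    v = tf.Variable(1.0, dtype=tf.float32, name="v")
    c = tf.add(b, v, name="c")
    # StopAtStepHook in map_func relies on the global step advancing on
    # every run() call.
    global_step = tf.train.get_or_create_global_step()
    global_step_inc = tf.assign_add(global_step, 1)
    return [c, global_step_inc]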
def flink_stream_train(context):
    tf_context = TFContext(context)
    job_name = tf_context.get_role_name()
    index = tf_context.get_index()
    cluster_json = tf_context.get_tf_cluster()
    export_model_path = tf_context.get_property("model_save_path")
    # flink_stream_dataset() exposes the records streamed in from Flink as a
    # TensorFlow dataset; train() is the user-defined training loop.
    train(cluster_json, job_name, index, export_model_path,
          tf_context.flink_stream_dataset())
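
# The train() helper called above is user code, not shown in this snippet. A
# hypothetical sketch (TF 1.x), assuming the Flink stream delivers CSV records
# with two float columns; the real schema, model, and export logic are
# application-specific:
def train(cluster_json, job_name, index, export_model_path, dataset):
    cluster = tf.train.ClusterSpec(cluster=cluster_json)
    server = tf.train.Server(cluster, job_name=job_name, task_index=index)
    if job_name == 'ps':
        server.join()  # parameter servers only serve variables
        return
    with tf.device(tf.train.replica_device_setter(
            worker_device='/job:worker/task:%d' % index, cluster=cluster)):
        # Decode each incoming record into two float columns (assumed schema).
        records = dataset.map(
            lambda r: tf.decode_csv(r, record_defaults=[[0.0], [0.0]]))
        x, y = records.batch(32).make_one_shot_iterator().get_next()
        loss = tf.reduce_mean(tf.square(x - y))
        global_step = tf.train.get_or_create_global_step()
        train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
            loss, global_step=global_step)
    # MonitoredTrainingSession stops once the Flink stream is exhausted
    # (OutOfRangeError); export_model_path would be used afterwards to write
    # a SavedModel, omitted here for brevity.
    with tf.train.MonitoredTrainingSession(master=server.target) as sess:
        while not sess.should_stop():
            sess.run(train_op)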