def map_fun(context):
    """Entry point executed by Flink for each task of a distributed TF1 job.

    Decodes CSV records streamed in from Flink, batches them, re-encodes a
    subset of the columns back to CSV, and increments a global step inside a
    ``MonitoredTrainingSession``.  Parameter-server ("ps") tasks only start
    their ``tf.train.Server`` and then park forever; Flink tears them down.

    Args:
        context: Flink-provided task context; wrapped by ``TFContext`` to
            obtain the role name, task index, and TF cluster description.
    """
    print(tf.__version__)
    sys.stdout.flush()

    tf_context = TFContext(context)
    job_name = tf_context.get_role_name()
    index = tf_context.get_index()
    cluster_json = tf_context.get_tf_cluster()
    print(cluster_json)
    sys.stdout.flush()

    cluster = tf.train.ClusterSpec(cluster=cluster_json)
    server = tf.train.Server(cluster, job_name=job_name, task_index=index)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % index])

    if 'ps' == job_name:
        # PS tasks never finish on their own; they are killed externally.
        from time import sleep
        while True:
            sleep(1)
    else:
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device='/job:worker/task:' + str(index),
                    cluster=cluster)):
            # The defaults also fix the decoded column dtypes:
            # int32, int64, float32, float64, string.
            record_defaults = [[9],
                               [tf.constant(value=9, dtype=tf.int64)],
                               [9.0],
                               [tf.constant(value=9.0, dtype=tf.float64)],
                               ["9.0"]]
            dataset = context.flinkStreamDataSet(buffer_size=0)
            dataset = dataset.map(lambda record: tf.decode_csv(
                record, record_defaults=record_defaults))
            dataset = dataset.batch(3)
            iterator = dataset.make_one_shot_iterator()
            input_records = iterator.get_next()

            global_step = tf.train.get_or_create_global_step()
            global_step_inc = tf.assign_add(global_step, 1)
            # Re-encode only columns 0, 2 and 4 back into a CSV string.
            out_list = [input_records[0], input_records[2], input_records[4]]
            out = tff_ops.encode_csv(input_list=out_list)

            is_chief = (index == 0)
            t = time.time()
            try:
                with tf.train.MonitoredTrainingSession(
                        master=server.target,
                        is_chief=is_chief,
                        config=sess_config,
                        checkpoint_dir="./target/tmp/input_output/"
                                       + str(t)) as mon_sess:
                    # Loop until the finite input stream is exhausted; that
                    # surfaces as an exception handled below.
                    while True:
                        print(index, mon_sess.run([global_step_inc, out]))
                        sys.stdout.flush()
            except Exception:
                # Best-effort logging: report the failure on stdout instead
                # of crashing the surrounding Flink task.
                print('traceback.print_exc():')
                traceback.print_exc()
                sys.stdout.flush()
            finally:
                SummaryWriterCache.clear()
def map_fun(context):
    """Entry point executed by Flink for each task of a distributed TF job
    (TF2 runtime with v1 compatibility behavior).

    Encodes three constant tensors into a '|'-delimited CSV string and writes
    it back to Flink through ``FlinkTFRecordWriter`` on every step, stopping
    after 50 global steps via ``StopAtStepHook``.  Parameter-server ("ps")
    tasks only start their server and then park forever; Flink tears them
    down.

    Args:
        context: Flink-provided task context; wrapped by ``TFContext`` to
            obtain the role name, task index, and TF cluster description,
            and queried via ``toFlink()`` for the writer address.
    """
    tf.compat.v1.disable_v2_behavior()

    tf_context = TFContext(context)
    job_name = tf_context.get_role_name()
    index = tf_context.get_index()
    cluster_json = tf_context.get_tf_cluster()
    print(cluster_json)
    sys.stdout.flush()

    cluster = tf.compat.v1.train.ClusterSpec(cluster=cluster_json)
    server = tf.compat.v1.train.Server(cluster,
                                       job_name=job_name,
                                       task_index=index)
    sess_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % index])

    if 'ps' == job_name:
        # PS tasks never finish on their own; they are killed externally.
        from time import sleep
        while True:
            sleep(1)
    else:
        with tf.compat.v1.device(
                tf.compat.v1.train.replica_device_setter(
                    worker_device='/job:worker/task:' + str(index),
                    cluster=cluster)):
            global_step = tf.compat.v1.train.get_or_create_global_step()
            global_step_inc = tf.compat.v1.assign_add(global_step, 1)
            # Fixed sample row: one int, one float, one string column.
            input_records = [
                tf.constant([1, 2, 3]),
                tf.constant([1.0, 2.0, 3.0]),
                tf.constant(['1.0', '2.0', '3.0'])
            ]
            out = tff_ops.encode_csv(input_list=input_records,
                                     field_delim='|')
            fw = tff_ops.FlinkTFRecordWriter(address=context.toFlink())
            w = fw.write([out])

            is_chief = (index == 0)
            t = time.time()
            try:
                # Stop cleanly once the global step reaches 50.
                hooks = [tf.compat.v1.train.StopAtStepHook(last_step=50)]
                with tf.compat.v1.train.MonitoredTrainingSession(
                        master=server.target,
                        config=sess_config,
                        is_chief=is_chief,
                        checkpoint_dir="./target/tmp/with_output/" + str(t),
                        hooks=hooks) as mon_sess:
                    while not mon_sess.should_stop():
                        print(index, mon_sess.run([global_step_inc, w]))
                        sys.stdout.flush()
                        time.sleep(1)
            finally:
                SummaryWriterCache.clear()