def map_fun(context):
    """Run one task of a TF1 cluster that echoes CSV records from Flink.

    Tasks whose role is ``ps`` block forever. Any other task reads records
    from ``context.flinkStreamDataSet``, decodes each one into five columns
    with ``tf.decode_csv``, batches them by 3, passes columns 0, 2 and 4 to
    ``tff_ops.encode_csv``, and prints the task index together with the
    result of each session step until the session loop exits.

    Args:
        context: Flink-provided task context. It is wrapped in ``TFContext``
            to obtain the role name, task index and cluster definition, and
            is also used directly to build the input dataset.
    """
    print(tf.__version__)
    sys.stdout.flush()
    tf_context = TFContext(context)
    job_name = tf_context.get_role_name()
    index = tf_context.get_index()
    cluster_json = tf_context.get_tf_cluster()
    print(cluster_json)
    sys.stdout.flush()
    cluster = tf.train.ClusterSpec(cluster=cluster_json)
    server = tf.train.Server(cluster, job_name=job_name, task_index=index)
    # Only the ps job and this worker's own task are visible to the session.
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % index])
    if 'ps' == job_name:
        # Keep the ps process (and the server created above) alive forever.
        while True:
            time.sleep(1)
    else:
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device='/job:worker/task:' + str(index),
                    cluster=cluster)):
            # The defaults fix each column's dtype for decode_csv:
            # int32, int64, float32, float64, string.
            record_defaults = [[9], [tf.constant(value=9, dtype=tf.int64)],
                               [9.0],
                               [tf.constant(value=9.0, dtype=tf.float64)],
                               ["9.0"]]
            dataset = context.flinkStreamDataSet(buffer_size=0)
            dataset = dataset.map(lambda record: tf.decode_csv(
                record, record_defaults=record_defaults))
            dataset = dataset.batch(3)
            iterator = dataset.make_one_shot_iterator()
            input_records = iterator.get_next()

            global_step = tf.train.get_or_create_global_step()
            global_step_inc = tf.assign_add(global_step, 1)
            # Re-emit only the int32, float32 and string columns.
            out_list = [input_records[0], input_records[2], input_records[4]]
            out = tff_ops.encode_csv(input_list=out_list)
            is_chief = (index == 0)
            # The timestamp suffix gives every run its own checkpoint dir.
            t = time.time()
            try:
                with tf.train.MonitoredTrainingSession(
                        master=server.target,
                        is_chief=is_chief,
                        config=sess_config,
                        checkpoint_dir="./target/tmp/input_output/" +
                        str(t)) as mon_sess:
                    # No should_stop() check: the loop presumably ends when
                    # the input stream is exhausted and run() raises (e.g.
                    # OutOfRangeError) -- TODO confirm.
                    while True:
                        print(index, mon_sess.run([global_step_inc, out]))
                        sys.stdout.flush()
            except Exception:
                # Best-effort: report the failure without propagating it.
                print('traceback.print_exc():')
                traceback.print_exc()
                sys.stdout.flush()
            finally:
                SummaryWriterCache.clear()
# Example #2
def map_fun(context):
    """Run one task of a TF cluster that writes constant CSV rows to Flink.

    Tasks whose role is ``ps`` block forever. Every other task builds a small
    graph that passes three constant columns to ``tff_ops.encode_csv`` with
    ``|`` as the field delimiter, and hands the result to a
    ``FlinkTFRecordWriter``. It then runs the write op together with a
    global-step increment, pausing one second between runs, until the
    ``StopAtStepHook(last_step=50)`` asks the session to stop.

    Args:
        context: Flink task context. It is wrapped in ``TFContext`` for the
            role, task index and cluster definition, and is used directly to
            obtain the writer address via ``toFlink()``.
    """
    tf.compat.v1.disable_v2_behavior()

    ctx = TFContext(context)
    role = ctx.get_role_name()
    task_index = ctx.get_index()
    cluster_def = ctx.get_tf_cluster()
    print(cluster_def)
    sys.stdout.flush()

    cluster_spec = tf.compat.v1.train.ClusterSpec(cluster=cluster_def)
    server = tf.compat.v1.train.Server(
        cluster_spec, job_name=role, task_index=task_index)
    # Only the ps job and this worker's own task are visible to the session.
    session_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % task_index])

    if role == 'ps':
        # Keep the ps process (and its server) alive indefinitely; this
        # loop never exits, so nothing below runs for ps tasks.
        while True:
            time.sleep(1)

    placement = tf.compat.v1.train.replica_device_setter(
        worker_device='/job:worker/task:' + str(task_index),
        cluster=cluster_spec)
    with tf.compat.v1.device(placement):
        global_step = tf.compat.v1.train.get_or_create_global_step()
        step_op = tf.compat.v1.assign_add(global_step, 1)
        columns = [
            tf.constant([1, 2, 3]),
            tf.constant([1.0, 2.0, 3.0]),
            tf.constant(['1.0', '2.0', '3.0']),
        ]
        csv_rows = tff_ops.encode_csv(input_list=columns, field_delim='|')
        writer = tff_ops.FlinkTFRecordWriter(address=context.toFlink())
        write_op = writer.write([csv_rows])
        # The timestamp suffix gives every run its own checkpoint directory.
        run_id = time.time()
        try:
            stop_hook = tf.compat.v1.train.StopAtStepHook(last_step=50)
            with tf.compat.v1.train.MonitoredTrainingSession(
                    master=server.target,
                    config=session_config,
                    is_chief=task_index == 0,
                    checkpoint_dir="./target/tmp/with_output/" + str(run_id),
                    hooks=[stop_hook]) as mon_sess:
                while not mon_sess.should_stop():
                    print(task_index, mon_sess.run([step_op, write_op]))
                    sys.stdout.flush()
                    time.sleep(1)
        finally:
            SummaryWriterCache.clear()