def main():
    """Run NUM_STEPS training steps with a TracingServer hook attached.

    The model's train op comes from ``get_model()``. Each step runs inside a
    ``tf.train.MonitoredTrainingSession`` whose only hook is the tracing
    server's hook. When training finishes, the collected trace is written to
    ``session.pickle``. Then ``join()`` is called, which per the original
    note keeps the tracing server running beyond training.
    """
    train_op = get_model()
    tracer = TracingServer()
    session_hooks = [tracer.hook]

    with tf.train.MonitoredTrainingSession(hooks=session_hooks) as sess:
        for _step in range(NUM_STEPS):
            sess.run(train_op)

    # Persist everything the tracer recorded during training.
    tracer.save_session("session.pickle")

    # Keeps the tracing server alive after training; drop this call if that
    # is not wanted.
    tracer.join()
Example #2
0
def main():
    """Train and then evaluate a DNNClassifier with tracing enabled.

    The classifier has 150 hidden layers of 10 units each and a single
    numeric feature ``"x"`` of shape ``INPUT_SIZE``. Both ``train`` and
    ``evaluate`` receive the same tracing hook, so both phases are traced.
    The trace is saved to ``session.pickle``. Then ``join()`` is called,
    which per the original note keeps the tracing server running afterwards.
    """
    layer_widths = [10] * 150
    feature_cols = [tf.feature_column.numeric_column("x", shape=INPUT_SIZE)]
    estimator = tf.estimator.DNNClassifier(
        hidden_units=layer_widths,
        feature_columns=feature_cols,
        n_classes=NUM_CLASSES,
    )

    tracer = TracingServer()
    # A fresh hooks list per call, matching the original.
    estimator.train(input_fn, hooks=[tracer.hook])
    estimator.evaluate(input_fn, hooks=[tracer.hook])

    # Persist the trace gathered across training and evaluation.
    tracer.save_session("session.pickle")

    # Keeps the tracing server alive after the run; drop this call if that
    # is not wanted.
    tracer.join()
Example #3
0
def main():
    """Train and evaluate a DNNClassifier with an explicitly configured tracer.

    Same workflow as the default-configured variant, except that the
    TracingServer is built from the module-level settings ``SERVER_IP``,
    ``SERVER_PORT``, ``KEEP_TRACES`` and ``START_WEB_SERVER_ON_START``
    (defined elsewhere in the module).
    """
    layer_widths = [10] * 150
    feature_cols = [tf.feature_column.numeric_column("x", shape=INPUT_SIZE)]
    estimator = tf.estimator.DNNClassifier(
        hidden_units=layer_widths,
        feature_columns=feature_cols,
        n_classes=NUM_CLASSES,
    )

    server_options = dict(
        server_ip=SERVER_IP,
        server_port=SERVER_PORT,
        keep_traces=KEEP_TRACES,
        start_web_server_on_start=START_WEB_SERVER_ON_START,
    )
    tracer = TracingServer(**server_options)

    estimator.train(input_fn, hooks=[tracer.hook])
    estimator.evaluate(input_fn, hooks=[tracer.hook])

    # Persist the trace gathered across training and evaluation.
    tracer.save_session("session.pickle")

    # Keeps the tracing server alive after the run; drop this call if that
    # is not wanted.
    tracer.join()
Example #4
0
def main(_):
    """Horovod training entry point that traces only on rank 0.

    Every rank initialises Horovod, builds the model and trains for
    ``NUM_STEPS`` steps with a ``BroadcastGlobalVariablesHook(0)``. Only
    rank 0 attaches a TracingServer hook. After training it saves the trace
    to ``session.pickle`` and calls ``join()``, which per the original note
    keeps the tracing server running.

    Args:
        _: Unused positional argument (presumably the argv list forwarded by
            the application runner; confirm against the caller).
    """
    hvd.init()
    train_op = get_model()

    hooks = [
        hvd.BroadcastGlobalVariablesHook(0),
    ]

    # Query the rank once and bind the tracer unconditionally. The teardown
    # below then depends on whether a tracer exists, not on re-evaluating a
    # separate condition.
    is_chief = hvd.rank() == 0
    tracing_server = None
    if is_chief:
        tracing_server = TracingServer()
        hooks.append(tracing_server.hook)

    with tf.train.MonitoredTrainingSession(hooks=hooks, **get_config()) as sess:
        for _step in range(NUM_STEPS):
            sess.run(train_op)

    if tracing_server is not None:
        # Save the tracing session
        tracing_server.save_session("session.pickle")

        # Keep the tracing server running beyond training. Remove otherwise.
        tracing_server.join()
Example #5
0
def main(_):
    """Horovod training entry point where every rank runs its own tracer.

    Each process listens on TCP port ``9999 + hvd.local_rank()``, so
    processes sharing a node do not collide. It creates its tracer with
    ``is_horovod=True`` and, after ``NUM_STEPS`` training steps, writes its
    trace to ``session-<rank>.pickle``. ``join()`` is then called, which per
    the original note keeps the tracer running after training.

    Args:
        _: Unused positional argument (presumably argv from the application
            runner; confirm against the caller).
    """
    hvd.init()
    train_op = get_model()

    broadcast_hook = hvd.BroadcastGlobalVariablesHook(0)

    # Colocated processes share a host: offset the port by the local rank.
    tracer = TracingServer(server_port=9999 + hvd.local_rank(),
                           is_horovod=True)
    hooks = [broadcast_hook, tracer.hook]

    session_kwargs = get_config()
    with tf.train.MonitoredTrainingSession(hooks=hooks,
                                           **session_kwargs) as sess:
        for _step in range(NUM_STEPS):
            sess.run(train_op)

    # One trace file per global rank.
    trace_path = "session-{}.pickle".format(hvd.rank())
    tracer.save_session(trace_path)

    # Keeps this rank's tracer alive after training; drop this call if that
    # is not wanted.
    tracer.join()
#! /usr/bin/env python -u
# coding=utf-8
import time

from tftracer import TracingServer

__author__ = 'Sayed Hadi Hashemi'

if __name__ == '__main__':
    # Reload a previously saved tracing session into a fresh server.
    saved_session_server = TracingServer()
    saved_session_server.load_session("session.pickle")

    # TODO(xldrx): Slow Server Workaround
    # Pause before join(); presumably gives the server time to come up.
    startup_grace_seconds = 5
    time.sleep(startup_grace_seconds)

    saved_session_server.join()