def main():
    """Run NUM_STEPS training steps with the tracing hook attached.

    Once the training session has closed, the collected traces are written
    to "session.pickle" and ``join()`` is called on the tracing server.
    """
    training_step = get_model()
    tracer = TracingServer()
    session_hooks = [tracer.hook]

    with tf.train.MonitoredTrainingSession(hooks=session_hooks) as session:
        for _step in range(NUM_STEPS):
            session.run(training_step)

    # Persist the traces gathered during training.
    tracer.save_session("session.pickle")

    # Keeps the tracing server running beyond training; drop this call if
    # that is not wanted.
    tracer.join()
Exemplo n.º 2
0
def main():
    """Train and evaluate a DNNClassifier, tracing both phases.

    The tracing hook is passed to ``train`` and to ``evaluate``. Afterwards
    the tracing session is written to "session.pickle" and ``join()`` is
    called on the tracing server.
    """
    input_column = tf.feature_column.numeric_column("x", shape=INPUT_SIZE)
    layer_sizes = [10] * 150  # 150 hidden layers of 10 units each
    classifier = tf.estimator.DNNClassifier(
        hidden_units=layer_sizes,
        feature_columns=[input_column],
        n_classes=NUM_CLASSES,
    )

    tracer = TracingServer()
    classifier.train(input_fn, hooks=[tracer.hook])
    classifier.evaluate(input_fn, hooks=[tracer.hook])

    # Persist the traces gathered during training and evaluation.
    tracer.save_session("session.pickle")

    # Keeps the tracing server running beyond training; drop this call if
    # that is not wanted.
    tracer.join()
Exemplo n.º 3
0
def main():
    """Train and evaluate a DNNClassifier with a configured tracing server.

    The TracingServer is built from the module-level SERVER_IP, SERVER_PORT,
    KEEP_TRACES and START_WEB_SERVER_ON_START settings. Its hook is passed
    to ``train`` and to ``evaluate``; afterwards the tracing session is
    written to "session.pickle" and ``join()`` is called on the server.
    """
    input_column = tf.feature_column.numeric_column("x", shape=INPUT_SIZE)
    layer_sizes = [10] * 150  # 150 hidden layers of 10 units each
    classifier = tf.estimator.DNNClassifier(
        hidden_units=layer_sizes,
        feature_columns=[input_column],
        n_classes=NUM_CLASSES,
    )

    tracer_options = dict(
        server_ip=SERVER_IP,
        server_port=SERVER_PORT,
        keep_traces=KEEP_TRACES,
        start_web_server_on_start=START_WEB_SERVER_ON_START,
    )
    tracer = TracingServer(**tracer_options)

    classifier.train(input_fn, hooks=[tracer.hook])
    classifier.evaluate(input_fn, hooks=[tracer.hook])

    # Persist the traces gathered during training and evaluation.
    tracer.save_session("session.pickle")

    # Keeps the tracing server running beyond training; drop this call if
    # that is not wanted.
    tracer.join()
Exemplo n.º 4
0
def main(_):
    """Run Horovod training where only the rank-0 process records traces.

    Every process runs NUM_STEPS steps with the variable-broadcast hook.
    The process whose ``hvd.rank()`` is 0 additionally attaches a tracing
    hook, saves the session to "session.pickle" after training and calls
    ``join()`` on the tracing server.

    Args:
        _: Positional argument supplied by the caller; it is only printed.
    """
    print(_)
    hvd.init()
    training_step = get_model()

    session_hooks = [hvd.BroadcastGlobalVariablesHook(0)]

    if hvd.rank() == 0:
        tracer = TracingServer()
        session_hooks.append(tracer.hook)

    with tf.train.MonitoredTrainingSession(hooks=session_hooks,
                                           **get_config()) as session:
        for _step in range(NUM_STEPS):
            session.run(training_step)

    # Only rank 0 created a tracer, so the other ranks are done here.
    if hvd.rank() != 0:
        return

    # Persist the traces gathered during training.
    tracer.save_session("session.pickle")

    # Keeps the tracing server running beyond training; drop this call if
    # that is not wanted.
    tracer.join()
Exemplo n.º 5
0
def main(_):
    """Run Horovod training with one tracing server per process.

    Each process starts its own TracingServer on port
    ``9999 + hvd.local_rank()``, runs NUM_STEPS steps, saves its traces to
    "session-<rank>.pickle" and then calls ``join()`` on its server.

    Args:
        _: Positional argument supplied by the caller; unused.
    """
    hvd.init()
    training_step = get_model()

    # Processes colocated on the same node each need a distinct TCP port.
    tracer_port = 9999 + hvd.local_rank()
    tracer = TracingServer(server_port=tracer_port, is_horovod=True)

    session_hooks = [hvd.BroadcastGlobalVariablesHook(0)]
    session_hooks.append(tracer.hook)

    with tf.train.MonitoredTrainingSession(hooks=session_hooks,
                                           **get_config()) as session:
        for _step in range(NUM_STEPS):
            session.run(training_step)

    # Persist this process's traces under a rank-specific file name.
    session_file = "session-{}.pickle".format(hvd.rank())
    tracer.save_session(session_file)

    # Keeps the tracing server running beyond training; drop this call if
    # that is not wanted.
    tracer.join()
Exemplo n.º 6
0
Arquivo: train.py Projeto: xz725/tfnet
def main(argv):
    """Build the TFNet estimator from FLAGS and run training.

    All settings are read from the module-level ``FLAGS`` object. The
    function builds the training (and, if ``FLAGS.testset`` is set, the
    evaluation) input function, the RunConfig, the summary list, the network
    configuration and the optimizer factory, then either calls
    ``tf.estimator.train_and_evaluate`` (test set given) or
    ``TFNetEstimator.train`` (no test set).

    Args:
        argv: Command-line arguments left unparsed by the flags library;
            only printed when ``FLAGS.debug`` is set.

    Raises:
        NotImplementedError: If ``FLAGS.audio_sample_rate`` is exactly 0.
    """
    if FLAGS.debug:
        print("Unprocessed flags:", argv)
        tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.debug('-------------------------------------------')
        tf.logging.debug('DEBUG MODE')
        tf.logging.debug('-------------------------------------------')

        tf.logging.debug('Time params:' + str(argshelper.get_time_params()))
        tf.logging.debug('Freq params:' + str(argshelper.get_freq_params()))
    else:
        # Only fall back to INFO outside debug mode; setting it
        # unconditionally would immediately cancel the DEBUG level above.
        tf.logging.set_verbosity(tf.logging.INFO)

    def degrade_fn(x):
        """Downsample ``x`` by the configured rate."""
        return ds.downsample_by(x, FLAGS.downsample_rate)

    train_input_fn = ds.get_dataset(
        FLAGS.trainset,
        path=FLAGS.datapath,
        degrade_fn=degrade_fn,
        epochs=FLAGS.epochs,
        batchsize=FLAGS.batchsize,
        segs_per_sample=FLAGS.batchsize // 4,
    )
    if FLAGS.testset:
        eval_input_fn = ds.get_dataset(
            FLAGS.testset,
            path=FLAGS.datapath,
            epochs=1,
            degrade_fn=degrade_fn,
            batchsize=FLAGS.batchsize,
            shuffle=False,
        )

    # --- Run configuration -------------------------------------------------
    if FLAGS.multigpu:
        config = argshelper.distribute.multi_gpu_config(
            log_step_count_steps=FLAGS.log_step_count_steps)
    else:
        config = tf.estimator.RunConfig(
            log_step_count_steps=FLAGS.log_step_count_steps)

    if FLAGS.save_checkpoints_steps:
        config = config.replace(
            save_checkpoints_steps=FLAGS.save_checkpoints_steps)

    if FLAGS.save_checkpoints_secs:
        config = config.replace(
            save_checkpoints_secs=FLAGS.save_checkpoints_secs)

    if FLAGS.usexla:
        sess_config = tf.ConfigProto()
        sess_config.graph_options.optimizer_options.global_jit_level = (
            tf.OptimizerOptions.ON_1)
        config = config.replace(session_config=sess_config)

    # --- Summaries ---------------------------------------------------------
    _summaries = []
    if FLAGS.audio_sample_rate > 0:
        _summaries.append(
            summaries.audio_sample_summary(FLAGS.audio_sample_rate))
    elif FLAGS.audio_sample_rate == 0:
        raise NotImplementedError('Automatic sample rate determination is '
                                  'not yet implemented')
    # A negative sample rate adds no audio-sample summary.
    _summaries.append(summaries.audio_spectrogram_summary())

    # --- Network configuration ---------------------------------------------
    net_kwargs = dict(
        time_params=argshelper.get_time_params(),
        freq_params=argshelper.get_freq_params(),
        window_length=FLAGS.window_length,
        transform=FLAGS.transform,
        fusion_op=FLAGS.fusion_op,
    )
    if FLAGS.spectral_copies:
        # Only this variant passes the downsample rate to the builder.
        net_config = nets.build_net(FLAGS.objective, FLAGS.downsample_rate,
                                    **net_kwargs)
    else:
        net_config = nets.build_net(FLAGS.objective, **net_kwargs)

    # --- Optimizer ----------------------------------------------------------
    # The learning rate and optimizers are wrapped in callables so that they
    # are only created when called, not while this function sets things up.
    if FLAGS.learning_rate_decay:
        def learning_rate():
            return tf.train.polynomial_decay(
                FLAGS.learning_rate,
                end_learning_rate=1e-6,
                global_step=tf.train.get_global_step(),
                decay_steps=500000,
                power=0.5)
    else:
        def learning_rate():
            return FLAGS.learning_rate

    optimizers = {
        'adam':
        lambda: tf.train.AdamOptimizer(learning_rate=learning_rate()),
        'sgd':
        lambda: tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate()),
    }

    print('------------------------------------------')
    print('Save steps: {}, ({}s)'.format(config.save_checkpoints_steps,
                                         config.save_checkpoints_secs))
    print('------------------------------------------')

    tfnet_est = TFNetEstimator(**net_config,
                               model_dir=FLAGS.model_dir,
                               add_summaries=_summaries,
                               optimizer=optimizers[FLAGS.optimizer],
                               weight_decay=FLAGS.weight_decay,
                               config=config)

    # --- Training hooks ------------------------------------------------------
    hooks = []
    if FLAGS.profile:
        hooks.append(
            tf.train.ProfilerHook(output_dir=FLAGS.model_dir,
                                  save_steps=500,
                                  show_memory=False))

    if FLAGS.enable_tracer:
        # tensorflow-tracer is optional: warn and carry on if it is missing.
        try:
            from tftracer import TracingServer
            tracing_server = TracingServer(server_port=8888)
            hooks.append(tracing_server.hook)
        except ImportError:
            tf.logging.warn("tensorflow-tracer not available. Will not be "
                            "enabled")

    # --- Run -----------------------------------------------------------------
    if FLAGS.testset:
        train_spec = tf.estimator.TrainSpec(train_input_fn, hooks=hooks)
        eval_spec = tf.estimator.EvalSpec(eval_input_fn, steps=None)
        # NOTE(review): a NaN loss restarts train_and_evaluate with no retry
        # limit, so a persistently diverging model loops here forever.
        while True:
            try:
                tf.estimator.train_and_evaluate(estimator=tfnet_est,
                                                train_spec=train_spec,
                                                eval_spec=eval_spec)
            except tf.estimator.NanLossDuringTrainingError:
                tf.logging.warn("NaN loss encountered. Attempting to continue")
                continue
            break
    else:
        tfnet_est.train(input_fn=train_input_fn, hooks=hooks)
#! /usr/bin/env python -u
# coding=utf-8
import time

from tftracer import TracingServer

__author__ = 'Sayed Hadi Hashemi'


def _serve_saved_session(session_path, startup_delay_secs):
    """Load a saved tracing session, wait, then call join() on the server."""
    tracing_server = TracingServer()
    tracing_server.load_session(session_path)

    # TODO(xldrx): Slow Server Workaround
    time.sleep(startup_delay_secs)

    tracing_server.join()


if __name__ == '__main__':
    _serve_saved_session("session.pickle", 5)