Example #1
0
def main(_):
    """Worker entry point: repeatedly rebuild the training graph and run unrolls.

    Every pass of the (never-terminating) outer loop:
      1. creates a fresh ``tf.Graph`` and makes it the default,
      2. seeds it with ``FLAGS.tf_seed``,
      3. builds the training graph via ``truncated_training.build_graph``,
      4. runs ``truncated_training.worker_run`` on it,
      5. resets session state on ``FLAGS.master`` (if a master is configured)
         before the next pass.

    Args:
      _: Unused positional argument (presumably argv from an app runner;
        TODO confirm).
    """
    train_dir = setup_experiment.setup_experiment("worker%d_logs" % FLAGS.task)

    run_fn = truncated_training.worker_run

    while True:
        logging.info("Starting a new graph reset iteration")
        g = tf.Graph()
        with g.as_default():
            # Hard-coded to 0 on every pass; forwarded to build_graph below.
            np_global_step = 0

            tf.set_random_seed(FLAGS.tf_seed)
            logging.info("building graph... with state from %d",
                         np_global_step)
            stime = time.time()
            graph_dict = truncated_training.build_graph(
                np_global_step=np_global_step)
            logging.info("done building graph... (took %f sec)",
                         time.time() - stime)

            # perform a series of unrolls
            run_fn(train_dir, graph_dict)

            # Reset the graph's session state on the master before the next
            # pass. "local" is checked first so it is reset exactly once and
            # without the worker-replica device filter, which only makes sense
            # for a distributed master.
            if FLAGS.master == "local":
                logging.info("running tf.session.Reset")
                config = tf.ConfigProto()
                tf.Session.reset(FLAGS.master, config=config)
            elif FLAGS.master:
                logging.info("running tf.session.Reset")
                # Restrict the reset to this worker's replica.
                config = tf.ConfigProto(
                    device_filters=["/job:worker/replica:%d" % FLAGS.task])
                tf.Session.reset(FLAGS.master, config=config)
def main(_):
    """Chief entry point: prepare the chief log directory, build the graph once,
    and hand both to ``truncated_training.chief_run``.

    NOTE(review): this is a second ``def main(_)`` in this file, so it shadows
    the worker ``main`` defined earlier; the two look like entry points of
    separate scripts — confirm they are meant to live in different modules.

    Args:
      _: Unused positional argument (presumably argv from an app runner;
        TODO confirm).
    """
    log_dir = setup_experiment.setup_experiment("chief_logs")
    # Graph construction happens after the log directory is set up, matching
    # the original call order.
    truncated_training.chief_run(log_dir, truncated_training.build_graph())