예제 #1
0
def main(argv):
    """Train on examples and export the updated model weights.

    Args:
      argv: command-line arguments. argv[0] is the program name; every
        remaining entry is passed to ``train`` as a training record path.
        At least one record path is required (the first and last are logged).

    Behaviour:
      * With ``FLAGS.dist_train``, ``hvd.init()`` is called before training and
        only the process with ``hvd.rank() == 0`` exports/freezes afterwards.
      * mllog output is configured to go to "train.log".
      * If ``FLAGS.export_path`` is set, the model is exported there and a
        "save_model" event is logged with the iteration number parsed from the
        final path component.
      * If ``FLAGS.freeze`` is set, the graph at ``FLAGS.export_path`` is frozen
        with the TensorRT / precision settings taken from FLAGS.
    """
    tf_records = argv[1:]
    logging.info("Training on %s records: %s to %s", len(tf_records),
                 tf_records[0], tf_records[-1])

    if FLAGS.dist_train:
        hvd.init()

    mllogger = mllog.get_mllogger()
    mllog.config(filename="train.log")

    mllog.config(default_namespace="worker1",
                 default_stack_offset=1,
                 default_clear_line=False)

    with utils.logged_timer("Training"):
        train(*tf_records)

    # Only a single process writes artifacts: the sole process when training
    # is not distributed, otherwise rank 0.
    if (not FLAGS.dist_train) or hvd.rank() == 0:
        if FLAGS.export_path:
            dual_net.export_model(FLAGS.export_path)
            # The last path component is expected to be the numeric iteration
            # (e.g. ".../000042"). normpath strips a trailing separator so a
            # path like "models/000042/" does not produce an empty basename
            # and a spurious ValueError from int().
            epoch = int(os.path.basename(os.path.normpath(FLAGS.export_path)))
            mllogger.event(key="save_model", value={"Iteration": epoch})
        if FLAGS.freeze:
            dual_net.freeze_graph(FLAGS.export_path, FLAGS.use_trt,
                                  FLAGS.trt_max_batch_size,
                                  FLAGS.trt_precision,
                                  FLAGS.selfplay_precision)
예제 #2
0
def run(state, rank, tstate, num_examples, tf_records):
    """Run one training pass from a restored train state, then save it back.

    Args:
      state: object exposing ``train_model_name``; rank 0 uses it to name the
        checkpoint written under ``FLAGS.export_path``.
      rank: this worker's rank. Only rank 0 writes the checkpoint (and freezes
        it when ``FLAGS.freeze`` is set).
      tstate: holder whose ``restore()`` returns
        ``(sess, train_op, data_iter, tf_records_ph)`` and whose ``save()``
        accepts the same four objects.
      num_examples: number of examples to train on. The planned step count is
        ``floor(num_examples / FLAGS.train_batch_size)``.
      tf_records: non-empty list of record paths fed into ``data_iter``
        through the ``tf_records_ph`` placeholder.
    """
    # restore train state
    sess, train_op, data_iter, tf_records_ph = tstate.restore()

    # calculate steps
    steps = math.floor(num_examples / FLAGS.train_batch_size)
    logging.info("Training, steps = %s, batch = %s -> %s examples", steps,
                 FLAGS.train_batch_size, steps * FLAGS.train_batch_size)

    # init input data
    logging.info(
        "[rank %d] [hvd rank %d] Training on %s records: %s to %s work_dir %s num_examples %d",
        rank, hvd.rank(), len(tf_records), tf_records[0], tf_records[-1],
        FLAGS.work_dir, num_examples)
    sess.run(data_iter.initializer, {tf_records_ph: tf_records})

    # train
    for step in range(steps):
        try:
            sess.run(train_op)
        except tf.errors.OutOfRangeError:
            # The input was exhausted before the planned step count. Report how
            # far training got, so the "steps = ..." line above is not mistaken
            # for the amount of work actually done.
            logging.warning("[rank %d] Input exhausted after %d of %d steps",
                            rank, step, steps)
            break

    # export graph
    if rank == 0:
        model_path = os.path.join(FLAGS.export_path, state.train_model_name)
        tf.train.Saver().save(sess, model_path)
        if FLAGS.freeze:
            dual_net.freeze_graph(model_path, FLAGS.use_trt,
                                  FLAGS.trt_max_batch_size,
                                  FLAGS.trt_precision)

    tstate.save(sess, train_op, data_iter, tf_records_ph)
예제 #3
0
def main(unused_argv):
    """Freeze the model at FLAGS.model_path into a GraphDef proto.

    TPU models use ``dual_net.freeze_graph_tpu``. All others go through
    ``dual_net.freeze_graph``, which receives the TensorRT-related flags
    (``use_trt``, ``trt_max_batch_size``, ``trt_precision``).
    """
    if not FLAGS.use_tpu:
        dual_net.freeze_graph(FLAGS.model_path, FLAGS.use_trt,
                              FLAGS.trt_max_batch_size, FLAGS.trt_precision)
        return
    dual_net.freeze_graph_tpu(FLAGS.model_path)
예제 #4
0
def main(unused_argv):
    """Freeze the model at FLAGS.model_path into a GraphDef proto.

    Dispatch:
      * ``FLAGS.use_tpu``      -> ``dual_net.freeze_graph_tpu(path)``
      * ``FLAGS.trt_batch > 0`` -> ``dual_net.freeze_graph(path, True, trt_batch)``
      * otherwise             -> ``dual_net.freeze_graph(path)``
    """
    model_path = FLAGS.model_path
    if FLAGS.use_tpu:
        dual_net.freeze_graph_tpu(model_path)
        return

    trt_batch = FLAGS.trt_batch
    if trt_batch > 0:
        freeze_args = (model_path, True, trt_batch)
    else:
        freeze_args = (model_path,)
    dual_net.freeze_graph(*freeze_args)
예제 #5
0
파일: train.py 프로젝트: zhiwuya/minigo
def main(argv):
    """Train on the record files named on the command line, then export.

    Args:
      argv: command-line arguments; argv[1:] are the training record paths
        (at least one is required, since the first and last are logged).

    After training, the model is exported to ``FLAGS.export_path`` when that
    flag is set. When ``FLAGS.freeze`` is set, the graph at that path is frozen
    with the TPU freezer or the generic one, depending on ``FLAGS.use_tpu``.
    """
    records = argv[1:]
    logging.info("Training on %s records: %s to %s",
                 len(records), records[0], records[-1])
    with utils.logged_timer("Training"):
        train(*records)

    export_path = FLAGS.export_path
    if export_path:
        dual_net.export_model(export_path)
    if not FLAGS.freeze:
        return
    freezer = (dual_net.freeze_graph_tpu if FLAGS.use_tpu
               else dual_net.freeze_graph)
    freezer(export_path)
예제 #6
0
def main(unused_argv):
    """Freeze a model to a GraphDef proto, then release an MPI parent if any.

    The model at ``FLAGS.model_path`` is frozen with the TPU freezer, with the
    TensorRT path (when ``FLAGS.trt_batch > 0``), or with the plain freezer.

    If this process has an MPI parent (``MPI.Comm.Get_parent()`` is not
    ``MPI.COMM_NULL``), it joins a barrier on that intercommunicator and
    disconnects. The handshake runs in a ``finally`` block, so a parent waiting
    on the matching barrier is released even when freezing raises. Any such
    exception still propagates after the handshake.
    """
    # Pin freezing to GPU index 7, intended as the "last GPU".
    # NOTE(review): hard-coded; that index is only the last device on an
    # 8-GPU host, and any pre-existing CUDA_VISIBLE_DEVICES is overwritten.
    os.environ["CUDA_VISIBLE_DEVICES"] = "7"
    try:
        if FLAGS.use_tpu:
            dual_net.freeze_graph_tpu(FLAGS.model_path)
        elif FLAGS.trt_batch > 0:
            dual_net.freeze_graph(FLAGS.model_path, True, FLAGS.trt_batch)
        else:
            dual_net.freeze_graph(FLAGS.model_path)
    finally:
        icomm = MPI.Comm.Get_parent()
        if icomm != MPI.COMM_NULL:
            icomm.barrier()
            icomm.Disconnect()
예제 #7
0
def main(argv):
    """Train on examples, then export the weights and a SavedModel.

    Args:
      argv: command-line arguments; argv[1:] are the training record paths.

    After ``train`` returns an estimator:
      * if ``FLAGS.export_path`` is set, ``dual_net.export_model`` writes there
        and the estimator's SavedModel is exported with that path as its base;
      * otherwise the SavedModel is exported with base directory 'saved_model'.
    When ``FLAGS.freeze`` is set, the graph at ``FLAGS.export_path`` is frozen
    with the TPU freezer or the generic one, depending on ``FLAGS.use_tpu``.
    """
    records = argv[1:]
    logging.info("Training on %s records: %s to %s", len(records),
                 records[0], records[-1])
    with utils.logged_timer("Training"):
        estimator = train(*records)

    export_path = FLAGS.export_path
    if export_path:
        dual_net.export_model(export_path)
    saved_model_base = export_path or 'saved_model'
    estimator.export_saved_model(saved_model_base,
                                 serving_input_receiver_fn())

    if FLAGS.freeze:
        freezer = (dual_net.freeze_graph_tpu if FLAGS.use_tpu
                   else dual_net.freeze_graph)
        freezer(export_path)
예제 #8
0
def main(argv):
    """Train on the given record files and export the resulting weights.

    Args:
      argv: command-line arguments; argv[1:] are the training record paths.

    With ``FLAGS.dist_train``, ``hvd.init()`` runs first and only the process
    whose ``hvd.rank()`` is 0 exports/freezes after training.
    """
    if FLAGS.dist_train:
        hvd.init()

    # Record batch size and learning-rate schedule via `mll` (presumably the
    # MLPerf compliance logger) before training starts.
    mll.global_batch_size(FLAGS.train_batch_size)
    mll.lr_rates(FLAGS.lr_rates)
    mll.lr_boundaries(FLAGS.lr_boundaries)

    records = argv[1:]
    logging.info("Training on %s records: %s to %s", len(records),
                 records[0], records[-1])
    with utils.logged_timer("Training"):
        train(*records)

    # Only one process writes artifacts: the sole process when training is
    # not distributed, otherwise rank 0.
    is_chief = not FLAGS.dist_train or hvd.rank() == 0
    if not is_chief:
        return

    export_path = FLAGS.export_path
    if export_path:
        dual_net.export_model(export_path)
    if FLAGS.freeze:
        freezer = (dual_net.freeze_graph_tpu if FLAGS.use_tpu
                   else dual_net.freeze_graph)
        freezer(export_path)