Beispiel #1
0
def main(unused_argv):
    """Freeze the model at ``--model_path`` into a GraphDef proto.

    TPU models are frozen with ``dual_net.freeze_graph_tpu``. All other models
    go through ``dual_net.freeze_graph``, which also receives the
    ``--use_trt``, ``--trt_max_batch_size`` and ``--trt_precision`` flags.

    Args:
      unused_argv: command-line arguments (ignored).
    """
    if not FLAGS.use_tpu:
        dual_net.freeze_graph(
            FLAGS.model_path,
            FLAGS.use_trt,
            FLAGS.trt_max_batch_size,
            FLAGS.trt_precision,
        )
        return
    dual_net.freeze_graph_tpu(FLAGS.model_path)
Beispiel #2
0
def main(unused_argv):
    """Freeze the model at ``--model_path`` into a GraphDef proto.

    Chooses one of three freeze paths:
      * ``--use_tpu`` set: ``dual_net.freeze_graph_tpu``.
      * ``--trt_batch`` positive: ``dual_net.freeze_graph`` with ``True``
        and ``--trt_batch`` as extra positional arguments.
      * otherwise: ``dual_net.freeze_graph`` with only the model path.

    Args:
      unused_argv: command-line arguments (ignored).
    """
    if FLAGS.use_tpu:
        dual_net.freeze_graph_tpu(FLAGS.model_path)
        return

    if FLAGS.trt_batch > 0:
        # NOTE(review): positional order presumably (path, use_trt, batch);
        # confirm against dual_net.freeze_graph's signature.
        dual_net.freeze_graph(FLAGS.model_path, True, FLAGS.trt_batch)
        return

    dual_net.freeze_graph(FLAGS.model_path)
Beispiel #3
0
def main(argv):
    """Train on examples and export the updated model weights.

    Args:
      argv: command-line arguments. ``argv[0]`` is the program name; every
        following entry is a training record passed to ``train`` (presumably
        a TFRecord file path).

    Raises:
      ValueError: if no training records were given.
    """
    tf_records = argv[1:]
    if not tf_records:
        # Fail fast with a clear message instead of an IndexError from the
        # log line below, before any training work starts.
        raise ValueError("No training records were given.")
    logging.info("Training on %s records: %s to %s",
                 len(tf_records), tf_records[0], tf_records[-1])
    with utils.logged_timer("Training"):
        train(*tf_records)
    if FLAGS.export_path:
        dual_net.export_model(FLAGS.export_path)
    if FLAGS.freeze:
        # NOTE(review): freezing reads --export_path even when it is unset;
        # confirm --freeze is only used together with --export_path.
        if FLAGS.use_tpu:
            dual_net.freeze_graph_tpu(FLAGS.export_path)
        else:
            dual_net.freeze_graph(FLAGS.export_path)
Beispiel #4
0
def main(unused_argv):
    """Freeze the model at ``--model_path`` into a GraphDef proto.

    The process is first restricted to CUDA device "7" through
    ``CUDA_VISIBLE_DEVICES``. After freezing, if this process has an MPI
    parent communicator, it joins a barrier with the parent and then
    disconnects from it.

    Args:
      unused_argv: command-line arguments (ignored).
    """
    # Use the last GPU for the freeze.
    # NOTE(review): "7" assumes an 8-GPU host; confirm for other machines.
    os.environ["CUDA_VISIBLE_DEVICES"] = "7"

    if not FLAGS.use_tpu:
        if FLAGS.trt_batch > 0:
            dual_net.freeze_graph(FLAGS.model_path, True, FLAGS.trt_batch)
        else:
            dual_net.freeze_graph(FLAGS.model_path)
    else:
        dual_net.freeze_graph_tpu(FLAGS.model_path)

    # Only synchronize when a parent exists (e.g. spawned via MPI);
    # otherwise Get_parent() yields COMM_NULL.
    parent_comm = MPI.Comm.Get_parent()
    if parent_comm != MPI.COMM_NULL:
        parent_comm.barrier()
        parent_comm.Disconnect()
Beispiel #5
0
def main(argv):
    """Train on examples and export the updated model weights and a SavedModel.

    When ``--export_path`` is set, the weights are exported with
    ``dual_net.export_model`` and the estimator's SavedModel is written under
    that path. Otherwise the SavedModel is written under ``saved_model``.

    Args:
      argv: command-line arguments. ``argv[0]`` is the program name; every
        following entry is a training record passed to ``train`` (presumably
        a TFRecord file path).

    Raises:
      ValueError: if no training records were given.
    """
    tf_records = argv[1:]
    if not tf_records:
        # Fail fast with a clear message instead of an IndexError from the
        # log line below, before any training work starts.
        raise ValueError("No training records were given.")
    logging.info("Training on %s records: %s to %s", len(tf_records),
                 tf_records[0], tf_records[-1])
    with utils.logged_timer("Training"):
        estimator = train(*tf_records)

    if FLAGS.export_path:
        dual_net.export_model(FLAGS.export_path)
        saved_model_base = FLAGS.export_path
    else:
        saved_model_base = 'saved_model'
    # NOTE(review): the *result* of serving_input_receiver_fn() is passed
    # here; presumably it is a factory returning the receiver fn - confirm.
    estimator.export_saved_model(saved_model_base,
                                 serving_input_receiver_fn())

    if FLAGS.freeze:
        # NOTE(review): freezing reads --export_path even when it is unset;
        # confirm --freeze is only used together with --export_path.
        if FLAGS.use_tpu:
            dual_net.freeze_graph_tpu(FLAGS.export_path)
        else:
            dual_net.freeze_graph(FLAGS.export_path)
Beispiel #6
0
def main(argv):
    """Train on examples and export the updated model weights.

    With ``--dist_train`` the run is distributed via ``hvd`` (presumably
    Horovod), and only rank 0 exports or freezes the model. Without it,
    the single process does the export itself.

    Args:
      argv: command-line arguments. ``argv[0]`` is the program name; every
        following entry is a training record passed to ``train`` (presumably
        a TFRecord file path).

    Raises:
      ValueError: if no training records were given.
    """
    tf_records = argv[1:]
    if not tf_records:
        # Validate before initializing distributed state or logging
        # hyper-parameters, and avoid an IndexError from the log line below.
        raise ValueError("No training records were given.")

    if FLAGS.dist_train:
        hvd.init()
    mll.global_batch_size(FLAGS.train_batch_size)
    mll.lr_rates(FLAGS.lr_rates)
    mll.lr_boundaries(FLAGS.lr_boundaries)
    logging.info("Training on %s records: %s to %s", len(tf_records),
                 tf_records[0], tf_records[-1])
    with utils.logged_timer("Training"):
        train(*tf_records)

    # Only one process writes model artifacts: rank 0 when distributed,
    # or the sole process otherwise.
    if (not FLAGS.dist_train) or hvd.rank() == 0:
        if FLAGS.export_path:
            dual_net.export_model(FLAGS.export_path)
        if FLAGS.freeze:
            if FLAGS.use_tpu:
                dual_net.freeze_graph_tpu(FLAGS.export_path)
            else:
                dual_net.freeze_graph(FLAGS.export_path)