def main(unused_argv):
    """Freeze a model to a GraphDef proto.

    Dispatches to the TPU freezer when --use_tpu is set; otherwise
    freezes on the default device, forwarding the TensorRT options.
    """
    if FLAGS.use_tpu:
        dual_net.freeze_graph_tpu(FLAGS.model_path)
        return
    # Non-TPU path: TRT conversion is controlled by the flags below.
    dual_net.freeze_graph(FLAGS.model_path,
                          FLAGS.use_trt,
                          FLAGS.trt_max_batch_size,
                          FLAGS.trt_precision)
def main(unused_argv):
    """Freeze a model to a GraphDef proto.

    TPU models use the dedicated TPU freezer. On other devices,
    a positive --trt_batch enables TensorRT conversion with that
    batch size; otherwise a plain freeze is performed.
    """
    if FLAGS.use_tpu:
        dual_net.freeze_graph_tpu(FLAGS.model_path)
        return
    if FLAGS.trt_batch > 0:
        # Enable TensorRT with the requested max batch size.
        dual_net.freeze_graph(FLAGS.model_path, True, FLAGS.trt_batch)
        return
    dual_net.freeze_graph(FLAGS.model_path)
def main(argv):
    """Train on examples and export the updated model weights.

    Positional args (argv[1:]) are the tf_record paths to train on.
    Optionally exports the model and freezes the resulting graph,
    depending on --export_path and --freeze.
    """
    tf_records = argv[1:]
    logging.info("Training on %s records: %s to %s",
                 len(tf_records), tf_records[0], tf_records[-1])

    with utils.logged_timer("Training"):
        train(*tf_records)

    if FLAGS.export_path:
        dual_net.export_model(FLAGS.export_path)

    if FLAGS.freeze:
        # Pick the freezer matching the device the model targets.
        freeze = (dual_net.freeze_graph_tpu if FLAGS.use_tpu
                  else dual_net.freeze_graph)
        freeze(FLAGS.export_path)
def main(unused_argv):
    """Freeze a model to a GraphDef proto.

    Pins the process to GPU 7, freezes the model (TPU, TensorRT, or
    plain path depending on flags), then synchronizes with and detaches
    from an MPI parent communicator if one spawned this process.
    """
    # Restrict this process to the last GPU — presumably so freezing
    # doesn't contend with workers on the other devices (TODO confirm).
    os.environ["CUDA_VISIBLE_DEVICES"] = "7"

    if FLAGS.use_tpu:
        dual_net.freeze_graph_tpu(FLAGS.model_path)
    elif FLAGS.trt_batch > 0:
        # Positive --trt_batch turns on TensorRT conversion.
        dual_net.freeze_graph(FLAGS.model_path, True, FLAGS.trt_batch)
    else:
        dual_net.freeze_graph(FLAGS.model_path)

    # If launched via MPI spawn, rendezvous with the parent before
    # tearing down the intercommunicator.
    icomm = MPI.Comm.Get_parent()
    if icomm != MPI.COMM_NULL:
        icomm.barrier()
        icomm.Disconnect()
def main(argv):
    """Train on examples and export the updated model weights.

    Positional args (argv[1:]) are the tf_record paths to train on.
    Always exports a SavedModel: to --export_path when given (alongside
    the weights export), otherwise to the local 'saved_model' directory.
    Optionally freezes the graph afterwards.
    """
    tf_records = argv[1:]
    logging.info("Training on %s records: %s to %s",
                 len(tf_records), tf_records[0], tf_records[-1])

    with utils.logged_timer("Training"):
        estimator = train(*tf_records)

    if FLAGS.export_path:
        dual_net.export_model(FLAGS.export_path)
        estimator.export_saved_model(FLAGS.export_path,
                                     serving_input_receiver_fn())
    else:
        # No export path given: still emit a SavedModel locally.
        estimator.export_saved_model('saved_model',
                                     serving_input_receiver_fn())

    if FLAGS.freeze:
        freeze = (dual_net.freeze_graph_tpu if FLAGS.use_tpu
                  else dual_net.freeze_graph)
        freeze(FLAGS.export_path)
def main(argv):
    """Train on examples and export the updated model weights.

    Initializes Horovod when --dist_train is set, logs MLPerf run
    parameters, trains on the tf_records in argv[1:], then (on the
    rank-0 worker only, for distributed runs) exports and optionally
    freezes the model.
    """
    if FLAGS.dist_train:
        hvd.init()
    # NOTE(review): collapsed source is ambiguous, but compliance
    # logging is reconstructed as unconditional — confirm against the
    # original file.
    mll.global_batch_size(FLAGS.train_batch_size)
    mll.lr_rates(FLAGS.lr_rates)
    mll.lr_boundaries(FLAGS.lr_boundaries)

    tf_records = argv[1:]
    logging.info("Training on %s records: %s to %s",
                 len(tf_records), tf_records[0], tf_records[-1])

    with utils.logged_timer("Training"):
        train(*tf_records)

    # Only one process should export/freeze in a distributed run.
    is_chief = (not FLAGS.dist_train) or hvd.rank() == 0
    if is_chief:
        if FLAGS.export_path:
            dual_net.export_model(FLAGS.export_path)
        if FLAGS.freeze:
            freeze = (dual_net.freeze_graph_tpu if FLAGS.use_tpu
                      else dual_net.freeze_graph)
            freeze(FLAGS.export_path)