def main(argv): """Train on examples and export the updated model weights.""" tf_records = argv[1:] logging.info("Training on %s records: %s to %s", len(tf_records), tf_records[0], tf_records[-1]) if FLAGS.dist_train: hvd.init() mllogger = mllog.get_mllogger() mllog.config(filename="train.log") mllog.config(default_namespace="worker1", default_stack_offset=1, default_clear_line=False) with utils.logged_timer("Training"): train(*tf_records) if (not FLAGS.dist_train) or hvd.rank() == 0: if FLAGS.export_path: dual_net.export_model(FLAGS.export_path) epoch = int(os.path.basename(FLAGS.export_path)) mllogger.event(key="save_model", value={"Iteration": epoch}) if FLAGS.freeze: dual_net.freeze_graph(FLAGS.export_path, FLAGS.use_trt, FLAGS.trt_max_batch_size, FLAGS.trt_precision, FLAGS.selfplay_precision)
def run(state, rank, tstate, num_examples, tf_records):
    # Restore the training session, train op, and input pipeline handles.
    sess, train_op, data_iter, tf_records_ph = tstate.restore()

    # Calculate how many full batches the available examples provide.
    steps = math.floor(num_examples / FLAGS.train_batch_size)
    logging.info("Training, steps = %s, batch = %s -> %s examples",
                 steps, FLAGS.train_batch_size,
                 steps * FLAGS.train_batch_size)

    # Initialize the input iterator with this run's TFRecord files.
    logging.info(
        "[rank %d] [hvd rank %d] Training on %s records: %s to %s "
        "work_dir %s num_examples %d",
        rank, hvd.rank(), len(tf_records), tf_records[0], tf_records[-1],
        FLAGS.work_dir, num_examples)
    sess.run(data_iter.initializer, {tf_records_ph: tf_records})

    # Train until the step budget or the input data is exhausted.
    for step in range(steps):
        try:
            sess.run(train_op)
        except tf.errors.OutOfRangeError:
            break

    # Export (and optionally freeze) the graph from rank 0 only.
    if rank == 0:
        model_path = os.path.join(FLAGS.export_path, state.train_model_name)
        tf.train.Saver().save(sess, model_path)
        if FLAGS.freeze:
            dual_net.freeze_graph(model_path, FLAGS.use_trt,
                                  FLAGS.trt_max_batch_size,
                                  FLAGS.trt_precision)

    tstate.save(sess, train_op, data_iter, tf_records_ph)
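# The block below is a minimal sketch, not this project's actual input pipeline,
# of how a data_iter / tf_records_ph pair like the one tstate.restore() returns
# above is commonly built in TF1: an initializable iterator keyed on a filename
# placeholder, so the same graph can be re-fed a fresh list of TFRecord files on
# every call to run(). The function name, compression, and batching are assumptions.
import tensorflow as tf

def build_tfrecord_input(batch_size, compression_type="ZLIB"):
    tf_records_ph = tf.placeholder(tf.string, shape=[None], name="tf_records")
    dataset = (tf.data.TFRecordDataset(tf_records_ph,
                                       compression_type=compression_type)
               .batch(batch_size, drop_remainder=True))
    data_iter = dataset.make_initializable_iterator()
    # Re-initialize for each new set of files:
    #   sess.run(data_iter.initializer, {tf_records_ph: tf_records})
    return data_iter, tf_records_ph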
def main(unused_argv): """Freeze a model to a GraphDef proto.""" if FLAGS.use_tpu: dual_net.freeze_graph_tpu(FLAGS.model_path) else: dual_net.freeze_graph(FLAGS.model_path, FLAGS.use_trt, FLAGS.trt_max_batch_size, FLAGS.trt_precision)
def main(unused_argv): """Freeze a model to a GraphDef proto.""" if FLAGS.use_tpu: dual_net.freeze_graph_tpu(FLAGS.model_path) elif FLAGS.trt_batch > 0: dual_net.freeze_graph(FLAGS.model_path, True, FLAGS.trt_batch) else: dual_net.freeze_graph(FLAGS.model_path)
def main(argv): """Train on examples and export the updated model weights.""" tf_records = argv[1:] logging.info("Training on %s records: %s to %s", len(tf_records), tf_records[0], tf_records[-1]) with utils.logged_timer("Training"): train(*tf_records) if FLAGS.export_path: dual_net.export_model(FLAGS.export_path) if FLAGS.freeze: if FLAGS.use_tpu: dual_net.freeze_graph_tpu(FLAGS.export_path) else: dual_net.freeze_graph(FLAGS.export_path)
def main(unused_argv): """Freeze a model to a GraphDef proto.""" # Use last GPU for freeze os.environ["CUDA_VISIBLE_DEVICES"] = "7" if FLAGS.use_tpu: dual_net.freeze_graph_tpu(FLAGS.model_path) elif FLAGS.trt_batch > 0: dual_net.freeze_graph(FLAGS.model_path, True, FLAGS.trt_batch) else: dual_net.freeze_graph(FLAGS.model_path) icomm = MPI.Comm.Get_parent() if icomm != MPI.COMM_NULL: icomm.barrier() icomm.Disconnect()
def main(argv): """Train on examples and export the updated model weights.""" tf_records = argv[1:] logging.info("Training on %s records: %s to %s", len(tf_records), tf_records[0], tf_records[-1]) with utils.logged_timer("Training"): estimator = train(*tf_records) if FLAGS.export_path: dual_net.export_model(FLAGS.export_path) estimator.export_saved_model(FLAGS.export_path, serving_input_receiver_fn()) else: estimator.export_saved_model('saved_model', serving_input_receiver_fn()) if FLAGS.freeze: if FLAGS.use_tpu: dual_net.freeze_graph_tpu(FLAGS.export_path) else: dual_net.freeze_graph(FLAGS.export_path)
def main(argv): """Train on examples and export the updated model weights.""" if FLAGS.dist_train: hvd.init() mll.global_batch_size(FLAGS.train_batch_size) mll.lr_rates(FLAGS.lr_rates) mll.lr_boundaries(FLAGS.lr_boundaries) tf_records = argv[1:] logging.info("Training on %s records: %s to %s", len(tf_records), tf_records[0], tf_records[-1]) with utils.logged_timer("Training"): train(*tf_records) if (not FLAGS.dist_train) or hvd.rank() == 0: if FLAGS.export_path: dual_net.export_model(FLAGS.export_path) if FLAGS.freeze: if FLAGS.use_tpu: dual_net.freeze_graph_tpu(FLAGS.export_path) else: dual_net.freeze_graph(FLAGS.export_path)