def freeze_graph(model_path, use_trt=False, trt_max_batch_size=8,
                 trt_precision='fp32', selfplay_precision='fp32'):
    """Freeze a trained DualNetwork checkpoint into an inference graph.

    Folds variables into constants, optionally runs the TensorRT graph
    converter, then writes either a `.minigo` model (when
    selfplay_precision == 'fp32') or a raw frozen `.pb` protobuf.

    Args:
        model_path: path to the trained checkpoint; also the output prefix.
        use_trt: if True, convert the frozen graph with TensorRT.
        trt_max_batch_size: maximum batch size for the TensorRT engine.
        trt_precision: TensorRT precision mode (e.g. 'fp32', 'fp16').
        selfplay_precision: selects the on-disk format for the output.
    """
    output_names = ['policy_output', 'value_output']
    net = DualNetwork(model_path)
    frozen = tf.compat.v1.graph_util.convert_variables_to_constants(
        net.sess, net.sess.graph.as_graph_def(), output_names)

    if use_trt:
        # Imported lazily so non-TRT builds don't need the contrib package.
        import tensorflow.contrib.tensorrt as trt
        frozen = trt.create_inference_graph(
            input_graph_def=frozen,
            outputs=output_names,
            max_batch_size=trt_max_batch_size,
            max_workspace_size_bytes=1 << 29,
            precision_mode=trt_precision)

    metadata = make_model_metadata({
        'engine': 'tf',
        'use_trt': bool(use_trt),
    })

    if selfplay_precision == 'fp32':
        minigo_model.write_graph_def(frozen, metadata, model_path + '.minigo')
    else:
        # Non-fp32 selfplay: emit the bare frozen GraphDef without metadata.
        with tf.io.gfile.GFile(model_path + '.pb', 'wb') as write_f:
            write_f.write(frozen.SerializeToString())
def freeze_graph(model_path, use_trt=False, trt_max_batch_size=8,
                 trt_precision='fp32'):
    """Freeze a trained DualNetwork checkpoint into a `.minigo` model.

    Folds variables into constants and, optionally, converts the frozen
    graph with TensorRT before writing `model_path + '.minigo'`.

    Args:
        model_path: path to the trained checkpoint; also the output prefix.
        use_trt: if True, convert the frozen graph with TensorRT.
        trt_max_batch_size: maximum batch size for the TensorRT engine.
        trt_precision: TensorRT precision mode (e.g. 'fp32', 'fp16').
    """
    output_names = ['policy_output', 'value_output']
    net = DualNetwork(model_path)
    frozen = tf.graph_util.convert_variables_to_constants(
        net.sess, net.sess.graph.as_graph_def(), output_names)

    if use_trt:
        # Imported lazily so non-TRT builds don't need the contrib package.
        import tensorflow.contrib.tensorrt as trt
        frozen = trt.create_inference_graph(
            input_graph_def=frozen,
            outputs=output_names,
            max_batch_size=trt_max_batch_size,
            max_workspace_size_bytes=1 << 29,
            precision_mode=trt_precision)

    metadata = make_model_metadata({
        'engine': 'tf',
        'use_trt': bool(use_trt),
    })

    minigo_model.write_graph_def(frozen, metadata, model_path + '.minigo')
def main(unused_argv):
    """Transpose conv layouts in an NCHW minigo model and rewrite it.

    Reads the model at FLAGS.src_path, applies the transpose selected by
    FLAGS.mode ('first_and_last' or 'all'), and writes the result to
    FLAGS.dst_path with the original metadata.

    Raises:
        ValueError: if FLAGS.mode is not a recognized transpose mode.
    """
    metadata, model_bytes = minigo_model.read_model(FLAGS.src_path)
    # Only NCHW source models are supported by the transpose passes.
    assert metadata['input_layout'] == 'nchw'

    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(model_bytes)

    mode = FLAGS.mode
    if mode == 'first_and_last':
        transpose_first_and_last_convs(graph_def, 'pos_tensor', 'Reshape')
    elif mode == 'all':
        transpose_all_convs(graph_def)
    else:
        raise ValueError('Unexpected transpose mode.')

    minigo_model.write_graph_def(graph_def, metadata, FLAGS.dst_path)
def freeze_graph_tpu(model_path):
    """Custom freeze_graph implementation for Cloud TPU.

    Builds a graph that replicates the inference function across
    FLAGS.num_tpu_cores TPU cores, restores the checkpoint at model_path,
    folds variables into constants, and writes `model_path + '.minigo'`.

    Args:
        model_path: path to the trained checkpoint; also the output prefix.
    """
    assert model_path
    assert FLAGS.tpu_name

    # Resolve the TPU master address; FLAGS.tpu_name may already be a URL.
    if FLAGS.tpu_name.startswith('grpc://'):
        tpu_grpc_url = FLAGS.tpu_name
    else:
        resolver = contrib_cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=None, project=None)
        tpu_grpc_url = resolver.get_master()
    sess = tf.Session(tpu_grpc_url)

    output_names = []
    with sess.graph.as_default():
        # One placeholder per TPU core; inputs may be bool or float features.
        feature_type = tf.bool if FLAGS.bool_features else tf.float32
        replicated_features = []
        for core in range(FLAGS.num_tpu_cores):
            placeholder = tf.placeholder(
                feature_type, [None], name='pos_tensor_%d' % core)
            replicated_features.append((placeholder, ))
        outputs = contrib_tpu.replicate(
            tpu_model_inference_fn, replicated_features)

        # The replicate op assigns names like output_0_shard_0 to the output
        # names. Give them human readable names.
        for core, (policy_output, value_output, _) in enumerate(outputs):
            policy_name = 'policy_output_%d' % core
            value_name = 'value_output_%d' % core
            output_names.extend([policy_name, value_name])
            tf.identity(policy_output, policy_name)
            tf.identity(value_output, value_name)

        tf.train.Saver().restore(sess, model_path)

    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_names)

    metadata = make_model_metadata({
        'engine': 'tpu',
        'num_replicas': FLAGS.num_tpu_cores,
    })

    minigo_model.write_graph_def(frozen, metadata, model_path + '.minigo')
def freeze_graph(model_path, use_trt=False, trt_max_batch_size=8,
                 trt_precision='fp32'):
    """Freeze a DualNetwork checkpoint, keeping an fp32 eval copy.

    Writes three artifacts derived from the frozen graph:
      * `model_path + '.evalfp32minigo'` — the pre-TRT fp32 graph for eval.
      * `model_path + '.stagedmodel'`    — staged (double-buffered) copy.
      * `model_path + '.minigo'`         — the selfplay model.

    Args:
        model_path: path to the trained checkpoint; also the output prefix.
        use_trt: if True, convert the frozen graph with TensorRT.
        trt_max_batch_size: maximum batch size for the TensorRT engine.
        trt_precision: TensorRT precision mode (e.g. 'fp32', 'fp16').
    """
    output_names = ['policy_output', 'value_output']
    net = DualNetwork(model_path)
    frozen = tf.graph_util.convert_variables_to_constants(
        net.sess, net.sess.graph.as_graph_def(), output_names)

    # eval is always fp32, so let's store a eval copy before we trt.
    eval_metadata = make_model_metadata({
        'engine': 'tf',
        'use_trt': False,
    })
    minigo_model.write_graph_def(
        frozen, eval_metadata, model_path + '.evalfp32minigo')

    if use_trt:
        from tensorflow.python.compiler.tensorrt import trt_convert as trt
        converter = trt.TrtGraphConverter(
            input_graph_def=frozen,
            nodes_blacklist=output_names,
            max_batch_size=trt_max_batch_size,
            max_workspace_size_bytes=1 << 29,
            precision_mode=trt_precision)
        frozen = converter.convert()

    metadata = make_model_metadata({
        'engine': 'tf',
        'use_trt': bool(use_trt),
    })

    # double buffer model write
    minigo_model.write_graph_def(
        frozen, metadata, model_path + '.stagedmodel')
    minigo_model.write_graph_def(frozen, metadata, model_path + '.minigo')