Example 1
0
def freeze_graph(model_path,
                 use_trt=False,
                 trt_max_batch_size=8,
                 trt_precision='fp32',
                 selfplay_precision='fp32'):
    """Freeze a trained model's variables into a constant graph and save it.

    Args:
        model_path: checkpoint path of the model to freeze; the output file
            is written alongside it with a '.minigo' or '.pb' suffix.
        use_trt: if True, run the frozen graph through TensorRT optimization.
        trt_max_batch_size: maximum batch size for the TensorRT engine.
        trt_precision: TensorRT precision mode (e.g. 'fp32', 'fp16').
        selfplay_precision: when 'fp32', write a '.minigo' model file;
            otherwise write the raw serialized GraphDef as '.pb'.
    """
    output_names = ['policy_output', 'value_output']

    n = DualNetwork(model_path)
    out_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
        n.sess, n.sess.graph.as_graph_def(), output_names)

    if use_trt:
        # Import lazily: tensorflow.contrib.tensorrt only exists in TF 1.x
        # builds compiled with TensorRT support.
        import tensorflow.contrib.tensorrt as trt
        out_graph = trt.create_inference_graph(
            input_graph_def=out_graph,
            outputs=output_names,
            max_batch_size=trt_max_batch_size,
            max_workspace_size_bytes=1 << 29,
            precision_mode=trt_precision)

    metadata = make_model_metadata({
        'engine': 'tf',
        'use_trt': bool(use_trt),
    })

    if selfplay_precision == 'fp32':
        minigo_model.write_graph_def(out_graph, metadata,
                                     model_path + '.minigo')
    else:
        # Non-fp32 selfplay consumes a raw frozen GraphDef instead of the
        # .minigo container format.
        with tf.io.gfile.GFile(model_path + '.pb', 'wb') as write_f:
            write_f.write(out_graph.SerializeToString())
Example 2
0
def freeze_graph(model_path,
                 use_trt=False,
                 trt_max_batch_size=8,
                 trt_precision='fp32'):
    """Freeze the model at model_path into a constant-only '.minigo' graph.

    Optionally runs the frozen graph through TensorRT before writing.
    """
    output_names = ['policy_output', 'value_output']

    dual_net = DualNetwork(model_path)
    frozen_graph = tf.graph_util.convert_variables_to_constants(
        dual_net.sess, dual_net.sess.graph.as_graph_def(), output_names)

    if use_trt:
        # Lazy import: only needed when TensorRT optimization is requested.
        import tensorflow.contrib.tensorrt as trt
        frozen_graph = trt.create_inference_graph(
            input_graph_def=frozen_graph,
            outputs=output_names,
            max_batch_size=trt_max_batch_size,
            max_workspace_size_bytes=1 << 29,
            precision_mode=trt_precision)

    model_metadata = make_model_metadata({
        'engine': 'tf',
        'use_trt': bool(use_trt),
    })

    minigo_model.write_graph_def(frozen_graph, model_metadata,
                                 model_path + '.minigo')
Example 3
0
def main(unused_argv):
    """Transpose convolutions of an NCHW model and write the result out."""
    metadata, model_bytes = minigo_model.read_model(FLAGS.src_path)
    assert metadata['input_layout'] == 'nchw'

    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(model_bytes)

    # Validate the mode up front, then dispatch to the matching transform.
    mode = FLAGS.mode
    if mode not in ('first_and_last', 'all'):
        raise ValueError('Unexpected transpose mode.')
    if mode == 'first_and_last':
        transpose_first_and_last_convs(graph_def, 'pos_tensor', 'Reshape')
    else:
        transpose_all_convs(graph_def)

    minigo_model.write_graph_def(graph_def, metadata, FLAGS.dst_path)
Example 4
0
def freeze_graph_tpu(model_path):
    """Freeze a checkpoint into a constant graph replicated per TPU core."""

    assert model_path
    assert FLAGS.tpu_name
    # Accept either a raw grpc:// endpoint or a TPU name the cluster
    # resolver can look up.
    if FLAGS.tpu_name.startswith('grpc://'):
        tpu_grpc_url = FLAGS.tpu_name
    else:
        resolver = contrib_cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=None, project=None)
        tpu_grpc_url = resolver.get_master()
    sess = tf.Session(tpu_grpc_url)

    output_names = []
    with sess.graph.as_default():
        # Build one input placeholder per TPU core, then replicate the
        # inference function across all of them.
        feature_type = tf.bool if FLAGS.bool_features else tf.float32
        replicated_features = [
            (tf.placeholder(feature_type, [None],
                            name='pos_tensor_%d' % core),)
            for core in range(FLAGS.num_tpu_cores)
        ]
        outputs = contrib_tpu.replicate(tpu_model_inference_fn,
                                        replicated_features)

        # replicate() emits names like output_0_shard_0; alias each
        # policy/value tensor with a human readable identity op.
        for core, (policy_output, value_output, _) in enumerate(outputs):
            policy_name = 'policy_output_%d' % core
            value_name = 'value_output_%d' % core
            output_names += [policy_name, value_name]
            tf.identity(policy_output, policy_name)
            tf.identity(value_output, value_name)

        tf.train.Saver().restore(sess, model_path)

    out_graph = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_names)

    metadata = make_model_metadata({
        'engine': 'tpu',
        'num_replicas': FLAGS.num_tpu_cores,
    })

    minigo_model.write_graph_def(out_graph, metadata, model_path + '.minigo')
Example 5
0
def freeze_graph(model_path,
                 use_trt=False,
                 trt_max_batch_size=8,
                 trt_precision='fp32'):
    """Freeze model variables, optionally TRT-optimize, and write outputs.

    Always writes an fp32 eval copy ('.evalfp32minigo') before any TensorRT
    conversion, then writes the (possibly converted) selfplay model twice:
    '.stagedmodel' first, '.minigo' last.
    """
    output_names = ['policy_output', 'value_output']

    net = DualNetwork(model_path)
    frozen = tf.graph_util.convert_variables_to_constants(
        net.sess, net.sess.graph.as_graph_def(), output_names)

    # Eval always runs in fp32, so persist an fp32 copy before TensorRT
    # conversion can touch the graph.
    eval_metadata = make_model_metadata({
        'engine': 'tf',
        'use_trt': False,
    })
    minigo_model.write_graph_def(frozen, eval_metadata,
                                 model_path + '.evalfp32minigo')

    if use_trt:
        # Lazy import: the TF 2.x-style TRT converter is only needed here.
        from tensorflow.python.compiler.tensorrt import trt_convert as trt
        frozen = trt.TrtGraphConverter(
            input_graph_def=frozen,
            nodes_blacklist=output_names,
            max_batch_size=trt_max_batch_size,
            max_workspace_size_bytes=1 << 29,
            precision_mode=trt_precision).convert()

    selfplay_metadata = make_model_metadata({
        'engine': 'tf',
        'use_trt': bool(use_trt),
    })

    # Double-buffered write: stage a copy first so consumers can swap safely.
    minigo_model.write_graph_def(frozen, selfplay_metadata,
                                 model_path + '.stagedmodel')
    minigo_model.write_graph_def(frozen, selfplay_metadata,
                                 model_path + '.minigo')