def extract_weights(graph_def,
                    output_graph,
                    tf_version,
                    signature_def,
                    quantization_dtype=None):
    """Evaluate every Const node of a GraphDef and write the values out.

    Each 'Const' node is evaluated in a fresh graph/session, its value is
    collected into a weight manifest, and the fields listed in
    CLEARED_TENSOR_FIELDS are cleared from the node's "value" tensor so the
    data is not kept in the topology. The stripped GraphDef (as a dict) and
    the manifest are then handed to write_artifacts.

    Note: graph_def is modified in place.

    Args:
      graph_def: tf.GraphDef TensorFlow GraphDef proto object, which
        represents the model topology.
      output_graph: The location of the output graph; also echoed in the
        progress message.
      tf_version: Tensorflow version of the input graph.
      signature_def: the SignatureDef of the inference graph.
      quantization_dtype: An optional numpy dtype to quantize weights to for
        compression. Only np.uint8 and np.uint16 are supported.
    """
    const_nodes = []
    for candidate in graph_def.node:
        if candidate.op == 'Const':
            const_nodes.append(candidate)

    # Temporarily drop the (conditional) inputs of each constant so that it
    # can be evaluated on its own; they are put back after evaluation.
    saved_inputs = {}
    for node in const_nodes:
        saved_inputs[node.name] = node.input[:]
        del node.input[:]

    print('Writing weight file ' + output_graph + '...')
    manifest_entries = []

    import_graph = tf.Graph()
    fuse_prelu.register_prelu_func(import_graph)

    with tf.compat.v1.Session(graph=import_graph) as sess:
        tf.import_graph_def(graph_def, name='')
        for node in const_nodes:
            evaluated = import_graph.get_tensor_by_name(
                node.name + ':0').eval(session=sess)
            # Make sure the manifest always carries an ndarray.
            data = (evaluated if isinstance(evaluated, np.ndarray)
                    else np.array(evaluated))
            manifest_entries.append({'name': node.name, 'data': data})

            # Put the conditional inputs back on the node.
            node.input[:] = saved_inputs[node.name]

            # The value now lives in the manifest, so clear it from the
            # tensor proto kept in the topology.
            for cleared_field in CLEARED_TENSOR_FIELDS:
                node.attr["value"].tensor.ClearField(cleared_field)

    # The manifest is passed as a single weight group.
    topology = MessageToDict(graph_def)
    write_artifacts(topology, [manifest_entries],
                    output_graph,
                    tf_version,
                    signature_def,
                    quantization_dtype=quantization_dtype)
def optimize_graph(graph,
                   output_node_names,
                   output_graph,
                   tf_version,
                   quantization_dtype=None,
                   skip_op_check=False,
                   strip_debug_ops=False,
                   signature_def=None):
    """Takes a Python Graph object, optimizes it and writes out the result.

    The graph is validated, Prelu ops are fused, Grappler is run twice (with
    batch norm folding in between), Prelu is fused into _FusedConv2D, the
    result is validated again and finally handed to extract_weights.

    Args:
      graph: The frozen graph to optimize.
      output_node_names: List of output node names.
      output_graph: The location of the output graph.
      tf_version: Tensorflow version of the input graph.
      quantization_dtype: An optional numpy dtype to quantize weights to for
        compression. Only np.uint8 and np.uint16 are supported.
      skip_op_check: Bool whether to skip the op check.
      strip_debug_ops: Bool whether to strip debug ops.
      signature_def: Optional SignatureDef of the inference graph, forwarded
        to extract_weights. Defaults to None.

    Returns:
      The optimized GraphDef, after extract_weights has processed it.

    Raises:
      ValueError: If validate reports unsupported ops either before or after
        optimization.
    """
    fuse_prelu.register_prelu_func(graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    for output in output_node_names:
        graph.add_to_collection('train_op',
                                graph.get_operation_by_name(output))

    graph_def = graph.as_graph_def()

    unsupported = validate(graph_def.node, skip_op_check, strip_debug_ops)
    if unsupported:
        raise ValueError('Unsupported Ops in the model before optimization\n' +
                         ', '.join(unsupported))

    # Because TF break the Prelu op into 6 ops, for performance we are
    # fusing those ops into a single prelu
    optimized_graph = fuse_prelu.fuse_ops_for_prelu(graph_def)

    # first pass of grappler optimization, this is needed for batch norm folding.
    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    rewriter_config.optimizers[:] = [
        'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning',
        'constfold', 'arithmetic', 'dependency'
    ]
    if strip_debug_ops:
        rewriter_config.optimizers.insert(0, 'debug_stripper')

    optimized_graph = _run_grappler(config, optimized_graph, graph)

    # batch norm folding
    optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph)

    # set the device to CPU for all Conv2d nodes, since grappler remap optimizer
    # only support FusedConv2D for CPU.
    for node in optimized_graph.node:
        if node.op == 'Conv2D':
            node.device = '/device:CPU:0'

    # rerun grappler to fuse conv2d
    config.graph_options.rewrite_options.optimizers[:] = [
        'remap', 'constfold', 'arithmetic', 'dependency'
    ]

    optimized_graph = _run_grappler(config, optimized_graph, graph)

    # Since the grappler remap optimizer does not support prelu as the
    # activation function for _FusedConv2D op, we are doing it manually here.
    optimized_graph = fuse_prelu.fuse_prelu_with_fused_conv2d(optimized_graph)

    unsupported = validate(optimized_graph.node, skip_op_check,
                           strip_debug_ops)

    if unsupported:
        raise ValueError('Unsupported Ops in the model after optimization\n' +
                         ', '.join(unsupported))

    # quantization_dtype must be passed by keyword: the fourth positional
    # parameter of extract_weights is signature_def, not the dtype.
    extract_weights(optimized_graph, output_graph, tf_version,
                    signature_def,
                    quantization_dtype=quantization_dtype)
    # Return the optimized graph rather than this function object.
    return optimized_graph