Пример #1
0
def convert_variables_to_constants_large_model(func):
    """Freeze ``func`` into a graph def using TensorFlow-internal converters.

    This is a deliberate hack for large models: it reaches into private
    helpers of ``tensorflow.python.framework.convert_to_constants``. Which
    helper is used depends on the installed TensorFlow version:

    * versions starting with ``"2.2."`` call
      ``_convert_variables_to_constants_v2_impl``;
    * every other version builds a ``_FunctionConverterData`` and passes it
      to ``_replace_variables_by_constants``.

    If the private names cannot be imported, the callable returned by
    ``_not_implemented_tf_placeholder(<name>)`` is invoked (presumably it
    raises -- confirm against its definition).

    Args:
        func: The function object handed to the TensorFlow converter.

    Returns:
        The first element of the pair returned by the selected converter
        (the frozen graph def).
    """
    is_tf_2_2 = tf.__version__.startswith("2.2.")

    if not is_tf_2_2:
        # Default path: build the converter data and substitute constants.
        try:
            from tensorflow.python.framework.convert_to_constants import (  # pylint: disable=protected-access
                _FunctionConverterData,
                _replace_variables_by_constants,
            )
        except ImportError:
            _not_implemented_tf_placeholder("_replace_variables_by_constants")()
        data = _FunctionConverterData(
            func=func, lower_control_flow=False, aggressive_inlining=True)
        graph_def, _ = _replace_variables_by_constants(converter_data=data)
        return graph_def

    # TF 2.2.x path: a single private entry point does the whole conversion.
    try:
        from tensorflow.python.framework.convert_to_constants import (  # pylint: disable=protected-access
            _convert_variables_to_constants_v2_impl,
        )
    except ImportError:
        _not_implemented_tf_placeholder(
            "_convert_variables_to_constants_v2_impl")()
    graph_def, _ = _convert_variables_to_constants_v2_impl(
        func, lower_control_flow=False, aggressive_inlining=True)
    return graph_def
Пример #2
0
def convert_variables_to_constants_large_model(func):
    """Freeze ``func`` into a graph def, picking a strategy by TF version.

    This is a deliberate hack for large models that relies on private
    TensorFlow internals. Three strategies are used:

    * ``2.0.x`` / ``2.1.x``: temporarily replace
      ``convert_to_constants._construct_concrete_function`` with a stub that
      returns the graph def untouched, then call
      ``convert_variables_to_constants_v2``. The original function is
      restored afterwards, even on error.
    * ``2.2.x``: call ``_convert_variables_to_constants_v2_impl``.
    * anything else: temporarily wrap ``tensor_util.make_tensor_proto`` so a
      ``ValueError`` falls back to building the ``TensorProto`` by hand
      (the hack that avoids the 2GB check), then run
      ``_replace_variables_by_constants`` on the appropriate converter-data
      class. From ``2.6.0`` onward the eager or graph variant of the class
      is chosen depending on ``context.executing_eagerly()``. The original
      ``make_tensor_proto`` is restored afterwards, even on error.

    If a private name cannot be imported, the callable returned by
    ``_not_implemented_tf_placeholder(<name>)`` is invoked (presumably it
    raises -- confirm against its definition).

    Args:
        func: The function object handed to the TensorFlow converter.

    Returns:
        The frozen graph def produced by the selected strategy.
    """
    tf_version = tf.__version__

    # --- TF 2.0.x / 2.1.x -------------------------------------------------
    if tf_version.startswith(("2.1.", "2.0.")):
        from tensorflow.python.framework import convert_to_constants

        saved_construct_fn = convert_to_constants._construct_concrete_function  # pylint: disable=protected-access

        def _return_graph_def_only(func, output_graph_def, converted_input_indices):
            # Hand back the graph def without loading it to avoid a crash;
            # errors in the graph def are fixed later.
            return output_graph_def

        convert_to_constants._construct_concrete_function = _return_graph_def_only  # pylint: disable=protected-access
        try:
            return convert_to_constants.convert_variables_to_constants_v2(func, lower_control_flow=False)
        finally:
            # Always undo the monkey patch.
            convert_to_constants._construct_concrete_function = saved_construct_fn  # pylint: disable=protected-access

    # --- TF 2.2.x ---------------------------------------------------------
    if tf_version.startswith("2.2."):
        try:
            from tensorflow.python.framework.convert_to_constants import (  # pylint: disable=protected-access
                _convert_variables_to_constants_v2_impl,
            )
        except ImportError:
            _not_implemented_tf_placeholder("_convert_variables_to_constants_v2_impl")()
        graph_def, _ = _convert_variables_to_constants_v2_impl(
            func, lower_control_flow=False, aggressive_inlining=True)
        return graph_def

    # --- Every other version ----------------------------------------------
    try:
        from tensorflow.python.framework.convert_to_constants import (  # pylint: disable=protected-access
            _FunctionConverterData,
            _replace_variables_by_constants,
        )
    except ImportError:
        _not_implemented_tf_placeholder("_replace_variables_by_constants")()

    from tensorflow.python.framework import tensor_util, tensor_shape

    saved_make_tensor_proto = tensor_util.make_tensor_proto

    def _make_tensor_proto_unchecked(values, dtype=None, shape=None, verify_shape=False, allow_broadcast=False):
        # Hack to avoid the 2GB check: try the real implementation first and
        # assemble the proto manually only if it raises ValueError.
        try:
            return saved_make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
        except ValueError:
            proto_dtype = dtype if dtype is not None else tf.dtypes.as_dtype(values.dtype).as_datatype_enum
            shape_proto = tensor_shape.as_shape(values.shape).as_proto()
            proto = tensor_pb2.TensorProto(dtype=proto_dtype, tensor_shape=shape_proto)
            proto.tensor_content = values.tobytes()
            return proto

    tensor_util.make_tensor_proto = _make_tensor_proto_unchecked

    try:
        converter_cls = _FunctionConverterData
        if LooseVersion(tf_version) >= "2.6.0":
            from tensorflow.python.eager import context
            from tensorflow.python.framework.convert_to_constants import (
                _FunctionConverterDataInEager,
                _FunctionConverterDataInGraph,
            )
            converter_cls = (
                _FunctionConverterDataInEager
                if context.executing_eagerly()
                else _FunctionConverterDataInGraph
            )
        data = converter_cls(func=func, lower_control_flow=False, aggressive_inlining=True)
        graph_def, _ = _replace_variables_by_constants(converter_data=data)
        return graph_def
    finally:
        # Always undo the monkey patch.
        tensor_util.make_tensor_proto = saved_make_tensor_proto