def convert_variables_to_constants_large_model(func):
    """Freeze the variables of ``func`` into constants using private TF internals.

    For large models we use internal tf methods as a hack. The internal entry
    point depends on the TensorFlow version:

    * TF 2.2.x: ``_convert_variables_to_constants_v2_impl``.
    * Later versions: ``_replace_variables_by_constants`` fed with converter
      data. On TF >= 2.6 the converter-data class is chosen by execution mode
      (eager vs. graph), matching the TF >= 2.6 handling used elsewhere in this
      module.

    Args:
        func: the function whose variables are frozen; passed unchanged to the
            TF converter (presumably a ConcreteFunction — confirm at call sites).

    Returns:
        The frozen graph def (first element of the TF helper's result tuple).

    Raises:
        Whatever ``_not_implemented_tf_placeholder(name)()`` raises when the
        required private TF symbol cannot be imported (defined elsewhere in
        this module).
    """
    if tf.__version__.startswith("2.2."):
        try:
            from tensorflow.python.framework.convert_to_constants import \
                _convert_variables_to_constants_v2_impl  # pylint: disable=protected-access
        except ImportError:
            _not_implemented_tf_placeholder("_convert_variables_to_constants_v2_impl")()
        frozen_graph_def, _ = _convert_variables_to_constants_v2_impl(
            func, lower_control_flow=False, aggressive_inlining=True)
        return frozen_graph_def

    try:
        from tensorflow.python.framework.convert_to_constants import \
            _FunctionConverterData, _replace_variables_by_constants  # pylint: disable=protected-access
    except ImportError:
        _not_implemented_tf_placeholder("_replace_variables_by_constants")()

    # On TF >= 2.6 the converter data must come from the execution-mode specific
    # subclass rather than the base class (same selection as the extended
    # implementation of this function).
    function_converter = _FunctionConverterData
    if LooseVersion(tf.__version__) >= "2.6.0":
        from tensorflow.python.eager import context
        from tensorflow.python.framework.convert_to_constants import \
            _FunctionConverterDataInEager, _FunctionConverterDataInGraph  # pylint: disable=protected-access
        function_converter = (_FunctionConverterDataInEager if context.executing_eagerly()
                              else _FunctionConverterDataInGraph)

    converter_data = function_converter(func=func, lower_control_flow=False, aggressive_inlining=True)
    frozen_graph_def, _ = _replace_variables_by_constants(converter_data=converter_data)
    return frozen_graph_def
def convert_variables_to_constants_large_model(func):
    """Freeze the variables of ``func`` into constants using private TF internals.

    For large models we use internal tf methods as a hack. The code path
    depends on the TensorFlow version:

    * TF 2.0.x / 2.1.x: call ``convert_variables_to_constants_v2``. Its
      ``_construct_concrete_function`` is temporarily patched so the graph def
      is returned without being loaded (loading it crashes).
    * TF 2.2.x: call ``_convert_variables_to_constants_v2_impl``.
    * Later versions: call ``_replace_variables_by_constants``.
      ``tensor_util.make_tensor_proto`` is temporarily patched to bypass its
      2GB size check. On TF >= 2.6 the converter-data class is chosen by
      execution mode (eager vs. graph).

    Every monkey-patch is undone in a ``finally`` block.

    Args:
        func: the function whose variables are frozen; passed unchanged to the
            TF converter (presumably a ConcreteFunction — confirm at call sites).

    Returns:
        The frozen graph def.

    Raises:
        Whatever ``_not_implemented_tf_placeholder(name)()`` raises when the
        required private TF symbol cannot be imported (defined elsewhere in
        this module).
    """
    if tf.__version__.startswith(("2.1.", "2.0.")):
        from tensorflow.python.framework import convert_to_constants
        orig_fn = convert_to_constants._construct_concrete_function  # pylint: disable=protected-access

        def fake_construct_fn(func, output_graph_def, converted_input_indices):  # pylint: disable=unused-argument
            # Return graph_def without loading it to avoid crash. Will fix errors in graph_def later.
            return output_graph_def

        convert_to_constants._construct_concrete_function = fake_construct_fn  # pylint: disable=protected-access
        try:
            frozen_graph_def = convert_to_constants.convert_variables_to_constants_v2(func, lower_control_flow=False)
        finally:
            convert_to_constants._construct_concrete_function = orig_fn  # pylint: disable=protected-access
        return frozen_graph_def

    if tf.__version__.startswith("2.2."):
        try:
            from tensorflow.python.framework.convert_to_constants import \
                _convert_variables_to_constants_v2_impl  # pylint: disable=protected-access
        except ImportError:
            _not_implemented_tf_placeholder("_convert_variables_to_constants_v2_impl")()
        frozen_graph_def, _ = _convert_variables_to_constants_v2_impl(
            func, lower_control_flow=False, aggressive_inlining=True)
        return frozen_graph_def

    try:
        from tensorflow.python.framework.convert_to_constants import \
            _FunctionConverterData, _replace_variables_by_constants  # pylint: disable=protected-access
    except ImportError:
        _not_implemented_tf_placeholder("_replace_variables_by_constants")()
    from tensorflow.python.framework import tensor_util, tensor_shape

    make_tensor_proto_original = tensor_util.make_tensor_proto

    # Hack to avoid 2GB check: when the original raises ValueError, build the
    # TensorProto directly from the array's raw bytes.
    def make_tensor_proto_wrapped(values, dtype=None, shape=None, verify_shape=False, allow_broadcast=False):
        try:
            return make_tensor_proto_original(values, dtype, shape, verify_shape, allow_broadcast)
        except ValueError:
            if not hasattr(values, "tobytes"):
                # Only array-like values can be serialized below; keep the real error otherwise.
                raise
            # TensorProto.dtype holds a DataType enum. as_dtype accepts an enum value,
            # a tf.DType or a numpy dtype, so normalize whatever the caller passed.
            # NOTE(review): no cast is done; values is assumed to already have this dtype.
            dtype_enum = tf.dtypes.as_dtype(values.dtype if dtype is None else dtype).as_datatype_enum
            tensor_proto = tensor_pb2.TensorProto(
                dtype=dtype_enum,
                tensor_shape=tensor_shape.as_shape(values.shape).as_proto())
            tensor_proto.tensor_content = values.tobytes()
            return tensor_proto

    tensor_util.make_tensor_proto = make_tensor_proto_wrapped
    try:
        function_converter = _FunctionConverterData
        if LooseVersion(tf.__version__) >= "2.6.0":
            from tensorflow.python.eager import context
            from tensorflow.python.framework.convert_to_constants import \
                _FunctionConverterDataInEager, _FunctionConverterDataInGraph  # pylint: disable=protected-access
            function_converter = (_FunctionConverterDataInEager if context.executing_eagerly()
                                  else _FunctionConverterDataInGraph)
        converter_data = function_converter(func=func, lower_control_flow=False, aggressive_inlining=True)
        frozen_graph_def, _ = _replace_variables_by_constants(converter_data=converter_data)
    finally:
        # Always restore the global TF function, even if conversion fails.
        tensor_util.make_tensor_proto = make_tensor_proto_original
    return frozen_graph_def