def generate_back_converters(converters):
    """Derive reverse converters from forward ``generic_converter`` partials.

    Only entries of ``converters`` that are ``functools.partial`` objects whose
    ``func`` equals ``tf_pb_to_tf_py.generic_converter`` are used; all other
    entries are ignored.  For every such entry a new ``partial`` of the
    module-scope ``generic_converter`` is built, where:

    * the result is stored under the forward partial's ``target_name`` keyword
      (a ``KeyError`` is raised if that keyword is absent), and
    * the new partial's ``target_name`` is the key the forward partial had in
      ``converters``.

    Forward keywords are translated as follows (absent keywords default to
    ``False`` or an empty dict):

    * ``revert_inputs`` is passed through unchanged;
    * ``attrib_name_dict`` is run through ``utils.key_value_swapped`` and
      becomes ``attrib_name_dict``;
    * ``input_to_attrib_dict`` is run through ``utils.key_value_swapped`` and
      becomes ``attrib_to_input_dict``;
    * the keys of ``new_attribs`` become the list ``attribs_to_remove``.

    When several forward converters share a ``target_name``, the one visited
    last overwrites the earlier ones.

    :param converters: mapping of name -> converter callable.
    :return: dict mapping name -> reverse converter (a ``functools.partial``).
    """
    result = {}
    for name, forward in six.iteritems(converters):
        # Only generic partials carry enough metadata to be inverted
        # mechanically; hand-written converters are skipped.
        is_generic = isinstance(forward, partial) and forward.func == tf_pb_to_tf_py.generic_converter
        if not is_generic:
            continue

        kwargs = forward.keywords
        reverse_key = kwargs['target_name']
        result[reverse_key] = partial(
            generic_converter,
            target_name=name,
            revert_inputs=kwargs.get('revert_inputs', False),
            attrib_name_dict=utils.key_value_swapped(kwargs.get('attrib_name_dict', {})),
            attrib_to_input_dict=utils.key_value_swapped(kwargs.get('input_to_attrib_dict', {})),
            # Attributes the forward pass added must be dropped when going back.
            attribs_to_remove=list(six.iterkeys(kwargs.get('new_attribs', {}))),
        )
    return result
from __future__ import division, print_function, absolute_import import typing from functools import partial import numpy as np import six from nnef_tools.conversion.tensorflow import tf_pb_to_tf_py from nnef_tools.core import utils from nnef_tools.io.tensorflow.tf_graph import * from nnef_tools.shape_inference import shape_inference as infer # noinspection PyProtectedMember _tf_py_dtype_to_tf_pb_dtype = utils.key_value_swapped( tf_pb_to_tf_py._tf_py_dtype_by_tf_pb_dtype) def convert(tf_graph): # type: (TFGraph)->None for tensor in tf_graph.tensors: if tensor.is_variable: if tensor.data.dtype == np.int64: tensor.data = tensor.data.astype(np.int32) tensor.dtype = "int32" if tensor.data.dtype == np.float64: tensor.data = tensor.data.astype(np.float32) tensor.dtype = "float32"