def __init__(self, model, debug=False, **kwargs):
    """
    TensorFlow 2.x model loader.

    Initializes the shared loader state via ``TFLoader.__init__`` and then
    sets ``self.tfssa_passes``, the ordered list of tf_ssa graph passes.

    Parameters
    ----------
    model: Model created with TensorFlow 2.x
        One of the following model formats:
            - TensorFlow tf.keras.Model object or HDF5 (.h5 or .hdf5) file path
            - TensorFlow SavedModel directory path
            - TensorFlow list of concrete function(s)
    debug: bool, optional. Defaults to False.
        This flag should generally be False except for debugging purposes
        for diagnosing conversion errors. Setting this flag to True will
        cause graph pass errors to be ignored, forcefully returning a
        NetworkEnsemble object.
    kwargs: dict(str, Any), optional
        Dictionary of additional arguments.
    """
    TFLoader.__init__(self, model, debug, **kwargs)

    # tf_ssa graph passes. The order of this list matters:
    # - delete_unnecessary_constant_nodes must come right after
    #   constant_propagation.
    # - Namespace flattening (flatten_sub_graph_namespaces; an earlier note
    #   called it "flatten_while_loop_namespaces" -- presumably the same
    #   pass, confirm) must run after constant_propagation. It changes node
    #   names, and constant propagation relies on the original names to
    #   perform session.run(); renamed nodes are not understandable for TF.
    self.tfssa_passes = [
        constant_propagation,
        delete_unnecessary_constant_nodes,
        rewrite_control_flow_functions,
        flatten_sub_graph_namespaces,
        remove_variable_nodes,
        fuse_dilation_conv,
    ]
def __init__(self, model, debug=False, **kwargs):
    """
    TensorFlow 2.x model loader.

    All initialization is delegated to ``TFLoader.__init__``.

    Parameters
    ----------
    model: Model created with TensorFlow 2.x
        One of the following model formats:
            - TensorFlow tf.keras.Model object or HDF5 (.h5) file path
            - TensorFlow SavedModel directory path
            - TensorFlow list of concrete function(s)
    debug: bool, optional. Defaults to False.
        This flag should generally be False except for debugging purposes
        for diagnosing conversion errors. Setting this flag to True will
        cause graph pass errors to be ignored, forcefully returning a
        NetworkEnsemble object.
    kwargs: dict(str, Any), optional
        Dictionary of additional arguments.
    """
    # Forward model, debug and any extra keyword arguments unchanged to the
    # TFLoader initializer.
    TFLoader.__init__(self, model, debug, **kwargs)