def freeze_session(sess, input_names=None, output_names=None, get_tables=False):
    """Freezes the state of a session into a pruned computation graph.

    Args:
        sess: session whose graph is frozen; its variables are converted to constants.
        input_names: tensor names of the form "op:port" whose ops must be kept.
            None is treated as an empty list.
        output_names: tensor names of the form "op:port" of the requested outputs.
            None is treated as an empty list.
        get_tables: when true, also export the contents of the hash tables found
            in the frozen graph.

    Returns:
        The frozen graph_def, or a tuple (graph_def, initialized_tables) when
        get_tables is true. initialized_tables maps a table's shared_name to the
        (keys, values) pair returned by sess.run.
    """
    # Keep only the part before the first ':' (the op name). Names must contain a port suffix.
    output_node_names = [i.split(':')[:-1][0] for i in output_names or []]
    keep_var_names = [i.split(':')[:-1][0] for i in input_names or []]
    with sess.graph.as_default():
        output_node_names += [v.op.name for v in tf_global_variables()]
        output_node_names += keep_var_names
        graph_def = sess.graph.as_graph_def(add_shapes=True)
        for node in graph_def.node:
            node.device = ""
        graph_def = convert_variables_to_constants(sess, graph_def, output_node_names)
        table_info = get_hash_table_info(graph_def)
        if get_tables:
            initialized_tables = {}
            tf.tables_initializer().run(session=sess)
            for info in table_info:
                n = info.shared_name
                if n is None:
                    # Without a shared_name there is no key to store the table under.
                    continue
                h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=n)
                try:
                    k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
                    k, v = sess.run([k, v])
                    initialized_tables[n] = (k, v)
                except Exception:  # pylint: disable=broad-except
                    # Best effort: a table that cannot be exported is reported and skipped.
                    logger.warning("Could not initialize table with shared_name = %r", n)
            return graph_def, initialized_tables
    return graph_def
def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
    """Freeze a concrete function and export the hash tables it references.

    Args:
        trackable: object handed to _get_hash_table_info_from_trackable to look up table info.
        concrete_func: function to freeze. Its captures are temporarily modified and
            restored before this function returns or raises.
        inputs: input tensor names passed to from_function.
        outputs: output tensor names passed to from_function.
        large_model: forwarded to from_function.

    Returns:
        Tuple (frozen_graph, initialized_tables) where initialized_tables maps a table
        name to a (keys, values) pair of numpy arrays.

    Raises:
        ValueError: from from_function; a protobuf size error is replaced by a message
            suggesting large_model.
    """
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."

    # Avoid errors due to bug in TF freezing
    removed_resource_to_placeholder, placeholder_to_resource, graph_captures_copy, func_captures_copy = \
        _remove_non_variable_resources_from_captures(concrete_func)

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise
    finally:
        # The caller keeps using concrete_func, so put it back in working order
        # whether or not freezing succeeded.
        _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)

    table_info = get_hash_table_info(frozen_graph)
    placeholder_to_table_info = {}
    _get_hash_table_info_from_trackable(trackable, table_info,
                                        removed_resource_to_placeholder, placeholder_to_table_info)

    initialized_tables = {}
    for info in table_info:
        if info.shared_name is not None:
            h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=info.shared_name)
            n = info.shared_name
        elif info.resource_input in placeholder_to_resource and info.resource_input not in placeholder_to_table_info:
            # We found a lookup op with no corresponding HashTable op, but we can associate the placeholder input
            # from the op with the resource handle from graph captures and make up a shared_name
            h = placeholder_to_resource[info.resource_input]
            n = str(uuid.uuid4()).encode()
            info.shared_name = n
            placeholder_to_table_info[info.resource_input] = info
        else:
            # Found a lookup op but the corresponding HashTable op has already been found and processed.
            continue
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            # Best effort: a table that cannot be exported is reported and skipped.
            logger.warning("Could not initialize table with shared_name = %r", n)

    for placeholder in removed_resource_to_placeholder.values():
        if placeholder not in placeholder_to_table_info:
            logger.error("Could not find table resource to replace placeholder %s", placeholder)

    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)

    return frozen_graph, initialized_tables
def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
    """Freeze a concrete function and export the hash tables named in the frozen graph.

    Args:
        trackable: object handed to _get_hash_table_info_from_trackable to look up table info.
        concrete_func: function to freeze. Its captures are temporarily modified and
            restored before this function returns or raises.
        inputs: input tensor names passed to from_function.
        outputs: output tensor names passed to from_function.
        large_model: forwarded to from_function.

    Returns:
        Tuple (frozen_graph, initialized_tables) where initialized_tables maps a table's
        shared_name to a (keys, values) pair of numpy arrays.

    Raises:
        ValueError: from from_function; a protobuf size error is replaced by a message
            suggesting large_model.
    """
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."

    # Avoid errors due to bug in TF freezing
    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
        _remove_non_variable_resources_from_captures(concrete_func)

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise
    finally:
        # The caller keeps using concrete_func, so put it back in working order
        # whether or not freezing succeeded.
        _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)

    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    placeholder_to_table_info = {}
    # NOTE(review): presumably extends the three lists and fills placeholder_to_table_info
    # in place; confirm against the helper.
    _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
                                        removed_resource_to_placeholder, placeholder_to_table_info)

    initialized_tables = {}
    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            # Best effort: a table that cannot be exported is reported and skipped.
            logger.warning("Could not initialize table with shared_name = %r", n)

    for placeholder in removed_resource_to_placeholder.values():
        if placeholder not in placeholder_to_table_info:
            logger.error("Could not find table resource to replace placeholder %s", placeholder)

    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)

    return frozen_graph, initialized_tables
def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def,
                         concrete_function_index, large_model):
    """Load tensorflow graph from saved_model.

    Args:
        model_path: path passed to tf.saved_model.load.
        input_names: input tensor names, or None to derive them from the concrete function.
        output_names: output tensor names, or None to derive them from the concrete function.
        tag: saved_model tags; None falls back to ['serve'].
        signature_def: sequence with at most one signature name; None or empty selects
            the first usable signature unless concrete_function_index is given.
        concrete_function_index: index into imported.__call__.concrete_functions, or None.
        large_model: forwarded to from_function.

    Returns:
        Tuple (frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables,
        tensors_to_rename). tensors_to_rename maps output tensor names to structured
        output names and is only filled when output_names is None and the structured
        outputs are a dict.

    Raises:
        ValueError: from from_function; a protobuf size error is replaced by a message
            suggesting --large_model.
    """
    wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
    wrn_empty_tag = "'--tag' value is empty string. Using tag =[[]]"
    wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
    err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
    err_no_call = "Model doesn't contain usable concrete functions under __call__. Try --signature-def instead."
    err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
    err_no_sig = "No signatures found in model. Try --concrete_function instead."
    err_sig_nomatch = "Specified signature not in model %s"
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."

    if tag is None:
        tag = ['serve']
        logger.warning(wrn_no_tag)

    if tag == '':
        tag = [[]]
        logger.warning(wrn_empty_tag)

    # The branches below already treat a missing signature_def as "not specified";
    # normalise None so the length check does not fail first.
    signature_def = signature_def or []
    utils.make_sure(len(signature_def) < 2, err_many_sig, str(signature_def))
    imported = tf.saved_model.load(model_path, tags=tag)  # pylint: disable=no-value-for-parameter

    all_sigs = imported.signatures.keys()
    valid_sigs = [s for s in all_sigs if not s.startswith("_")]
    logger.info("Signatures found in model: %s", "[" + ",".join(valid_sigs) + "].")

    concrete_func = None
    if concrete_function_index is not None:
        utils.make_sure(hasattr(imported, "__call__"), err_no_call)
        utils.make_sure(concrete_function_index < len(imported.__call__.concrete_functions),
                        err_index, concrete_function_index, len(imported.__call__.concrete_functions) - 1)
        args, kwargs = imported.__call__.concrete_functions[concrete_function_index].structured_input_signature
        concrete_func = imported.__call__.get_concrete_function(*args, **kwargs)
    elif signature_def:
        utils.make_sure(signature_def[0] in valid_sigs, err_sig_nomatch, signature_def[0])
        concrete_func = imported.signatures[signature_def[0]]
    else:
        utils.make_sure(len(valid_sigs) > 0, err_no_sig)
        logger.warning(wrn_sig_1, valid_sigs[0])
        concrete_func = imported.signatures[valid_sigs[0]]

    tensors_to_rename = {}
    if input_names is None:
        inputs = [tensor.name for tensor in concrete_func.inputs if tensor.dtype != tf.dtypes.resource]
        if concrete_func.structured_input_signature is not None:
            args, kwargs = concrete_func.structured_input_signature
            structured_inputs = [t.name for t in args if isinstance(t, tf.TensorSpec)] + sorted(kwargs.keys())
            structured_inputs = set(inp + ":0" for inp in structured_inputs)
            # Only narrow to the structured inputs when at least one of them matches,
            # otherwise keep every non-resource input.
            if any(inp in structured_inputs for inp in inputs):
                inputs = [inp for inp in inputs if inp in structured_inputs]
    else:
        inputs = input_names

    if output_names is None:
        outputs = [tensor.name for tensor in concrete_func.outputs if tensor.dtype != tf.dtypes.resource]
        if isinstance(concrete_func.structured_outputs, dict):
            # outputs are sorted, sort structured_outputs the same way
            structured_outputs = sorted(concrete_func.structured_outputs.keys())
            tensors_to_rename.update(zip(outputs, structured_outputs))
            logger.info("Output names: %r", structured_outputs)
        else:
            logger.info("Output names: %r", outputs)
    else:
        outputs = output_names
        logger.info("Outputs not left as None; will use provided names not structured output names.")

    # Avoid errors due to bug in TF freezing
    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
        _remove_non_variable_resources_from_captures(concrete_func)

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise
    finally:
        # concrete_func belongs to the loaded model, so put it back in working order
        # whether or not freezing succeeded.
        _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)

    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    placeholder_to_table_info = {}
    # NOTE(review): presumably extends the three lists and fills placeholder_to_table_info
    # in place; confirm against the helper.
    _get_hash_table_info_from_trackable(imported, table_names, key_dtypes, value_dtypes,
                                        removed_resource_to_placeholder, placeholder_to_table_info)

    initialized_tables = {}
    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            # Best effort: a table that cannot be exported is reported and skipped.
            logger.warning("Could not initialize table with shared_name = %r", n)

    for placeholder in removed_resource_to_placeholder.values():
        if placeholder not in placeholder_to_table_info:
            logger.error("Could not find table resource to replace placeholder %s", placeholder)

    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)

    return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables, tensors_to_rename
def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signature_names):
    """Load tensorflow graph from saved_model.

    Args:
        sess: session the saved_model is loaded into and frozen from.
        model_path: path passed to tf.saved_model.loader.load.
        input_names: input tensor names, or None to collect them from the selected signatures.
        output_names: output tensor names, or None to collect them from the selected signatures.
        tag: a tag or list of tags; None falls back to the serving tag, '' to no tags.
        signature_names: signature keys to use; None or empty selects every signature
            whose key does not start with "_".

    Returns:
        Tuple (frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename).
        tensors_to_rename maps output tensor names to their key in the signature and is
        only filled when output_names is None.
    """
    wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
    wrn_empty_tag = "'--tag' value is empty string. Using tags = []"
    wrn_empty_sig = "'--signature_def' not provided. Using all signatures."

    if tag is None:
        tag = [tf.saved_model.tag_constants.SERVING]
        logger.warning(wrn_no_tag)

    if not signature_names:
        logger.warning(wrn_empty_sig)
        # Normalise None so the membership test below does not raise TypeError.
        signature_names = []

    if tag == '':
        tag = []
        logger.warning(wrn_empty_tag)

    if not isinstance(tag, list):
        tag = [tag]

    imported = tf.saved_model.loader.load(sess, tag, model_path)
    signatures = []
    for k in imported.signature_def.keys():
        if k in signature_names or (not signature_names and not k.startswith("_")):
            signatures.append(k)
    try:
        from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
        # pylint: disable=unnecessary-lambda
        get_signature_def = lambda meta_graph_def, k: \
            signature_def_utils.get_signature_def_by_key(meta_graph_def, k)
    except ImportError:
        # TF1.12 changed the api
        get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k]

    if input_names is None:
        input_names = []
        for k in signatures:
            inputs_tensor_info = get_signature_def(imported, k).inputs
            for _, input_tensor in inputs_tensor_info.items():
                if input_tensor.name not in input_names:
                    input_names.append(input_tensor.name)
    tensors_to_rename = {}
    if output_names is None:
        output_names = []
        for k in signatures:
            outputs_tensor_info = get_signature_def(imported, k).outputs
            for structured_name, output_tensor in outputs_tensor_info.items():
                if output_tensor.name not in output_names:
                    output_names.append(output_tensor.name)
                    tensors_to_rename[output_tensor.name] = structured_name
    frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)

    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    initialized_tables = {}
    # Run against the session we were given instead of relying on it being the default one.
    tf.tables_initializer().run(session=sess)
    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
            k, v = sess.run([k, v])
            initialized_tables[n] = (k, v)
        except Exception:  # pylint: disable=broad-except
            # Best effort: a table that cannot be exported is reported and skipped.
            logger.warning("Could not initialize table with shared_name = %r", n)

    return frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename
def freeze_and_run_tf(self, func, feed_dict, outputs, as_session, premade_placeholders, large_model, constant_fold):
    """Run func in TensorFlow and return its results together with the frozen graph.

    Args:
        func: callable building (graph mode) or computing (eager mode) the model.
        feed_dict: maps input tensor names to numpy values.
        outputs: output tensor names.
        as_session: when true, force the graph/session path even on TF2.
        premade_placeholders: when true, func is called without arguments in graph mode
            and no placeholders are created here.
        large_model: forwarded to from_function on the eager path.
        constant_fold: forwarded to tf_optimize as fold_constant on the session path.

    Returns:
        Tuple (result, graph_def, initialized_tables). initialized_tables is None on the
        eager path and a dict of shared_name -> (keys, values) on the session path.
    """
    np.random.seed(1)  # Make it reproducible.
    clean_feed_dict = {utils.node_name(k): v for k, v in feed_dict.items()}
    if is_tf2() and not as_session:
        #
        # use eager to execute the tensorflow func
        #
        # numpy doesn't work for all ops, make it tf.Tensor()
        input_tensors = [tf.TensorSpec(shape=v.shape, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                         for k, v in feed_dict.items()]
        input_list = [tf.convert_to_tensor(v, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                      for k, v in feed_dict.items()]
        tf.random.set_seed(1)
        result = func(*input_list)
        if isinstance(result, (list, tuple)):
            # list or tuple
            result = [x.numpy() for x in result]
        else:
            # single result
            result = [result.numpy()]

        # now make the eager functions a graph
        concrete_func = tf.function(func, input_signature=tuple(input_tensors))
        concrete_func = concrete_func.get_concrete_function()
        graph_def = from_function(concrete_func,
                                  input_names=list(feed_dict.keys()),
                                  output_names=outputs,
                                  large_model=large_model)
        initialized_tables = None
    else:
        #
        # use graph to execute the tensorflow func
        #
        with tf_session() as sess:
            tf_set_random_seed(1)
            input_list = []
            if not premade_placeholders:
                for k, v in clean_feed_dict.items():
                    input_list.append(tf_placeholder(name=k, shape=v.shape, dtype=tf.as_dtype(v.dtype)))
            func(*input_list)
            variables_lib.global_variables_initializer().run()
            tf_tables_initializer().run()

            output_dict = []
            for out_name in outputs:
                output_dict.append(sess.graph.get_tensor_by_name(out_name))
            result = sess.run(output_dict, feed_dict=feed_dict)
            graph_def = freeze_session(sess,
                                       input_names=list(feed_dict.keys()),
                                       output_names=outputs)
            table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
            initialized_tables = {}
            for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
                h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
                k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
                # Fetch keys and values in one run so they come from the same export.
                initialized_tables[n] = tuple(sess.run([k, v]))

        # NOTE(review): nesting reconstructed from a whitespace-collapsed source; the
        # optimize pass is assumed to belong to the session branch only - confirm.
        tf_reset_default_graph()
        with tf_session() as sess:
            tf.import_graph_def(graph_def, name='')
            graph_def = tf_optimize(list(feed_dict.keys()), outputs, graph_def, fold_constant=constant_fold)

    model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
    utils.save_protobuf(model_path, graph_def)
    self.logger.debug("created file %s", model_path)
    return result, graph_def, initialized_tables
def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
                  convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
                  check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
                  large_model=False, premade_placeholders=False):
    """Run func in TensorFlow, convert the graph, run the backend and compare the results.

    Args:
        func: callable building (graph mode) or computing (eager mode) the model.
        feed_dict: maps input tensor names to numpy values.
        input_names_with_port: overwritten with the keys of feed_dict.
        output_names_with_port: output tensor names.
        rtol, atol: tolerances passed to assertAllClose.
        convert_var_to_const: not read by this function.
        constant_fold: forwarded to tf_optimize as fold_constant on the session path.
        check_value, check_shape, check_dtype: select which comparisons are made.
        process_args: extra keyword arguments for process_tf_graph.
        onnx_feed_dict: feed for the backend run; defaults to feed_dict.
        graph_validator: optional callable taking the converted graph; must return a true value.
        as_session: when true, force the graph/session path even on TF2.
        large_model: enables graph_def compression and is forwarded to the backend run.
        premade_placeholders: when true, func is called without arguments in graph mode.

    Returns:
        The converted and optimized graph.
    """
    # optional - passed to process_tf_graph
    if process_args is None:
        process_args = {}
    # optional - pass distinct feed_dict to onnx runtime
    if onnx_feed_dict is None:
        onnx_feed_dict = feed_dict
    # NOTE(review): nesting reconstructed from a whitespace-collapsed source; the value
    # is not read again below, so the placement does not affect behaviour.
    input_names_with_port = list(feed_dict)
    tf_reset_default_graph()
    graph_def = None
    initialized_tables = None

    np.random.seed(1)  # Make it reproducible.
    clean_feed_dict = {utils.node_name(k): v for k, v in feed_dict.items()}
    if is_tf2() and not as_session:
        #
        # use eager to execute the tensorflow func
        #
        # numpy doesn't work for all ops, make it tf.Tensor()
        input_tensors = [tf.TensorSpec(shape=v.shape, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                         for k, v in feed_dict.items()]
        input_list = [tf.convert_to_tensor(v, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                      for k, v in feed_dict.items()]
        tf.random.set_seed(1)
        expected = func(*input_list)
        if isinstance(expected, (list, tuple)):
            # list or tuple
            expected = [x.numpy() for x in expected]
        else:
            # single result
            expected = [expected.numpy()]

        # now make the eager functions a graph
        concrete_func = tf.function(func, input_signature=tuple(input_tensors))
        concrete_func = concrete_func.get_concrete_function()
        graph_def = from_function(concrete_func,
                                  input_names=list(feed_dict.keys()),
                                  output_names=output_names_with_port,
                                  large_model=large_model)
    else:
        #
        # use graph to execute the tensorflow func
        #
        with tf_session() as sess:
            tf_set_random_seed(1)
            input_list = []
            if not premade_placeholders:
                for k, v in clean_feed_dict.items():
                    input_list.append(tf_placeholder(name=k, shape=v.shape, dtype=tf.as_dtype(v.dtype)))
            func(*input_list)
            variables_lib.global_variables_initializer().run()
            tf_tables_initializer().run()

            output_dict = []
            for out_name in output_names_with_port:
                output_dict.append(sess.graph.get_tensor_by_name(out_name))
            expected = sess.run(output_dict, feed_dict=feed_dict)
            graph_def = freeze_session(sess,
                                       input_names=list(feed_dict.keys()),
                                       output_names=output_names_with_port)
            table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
            initialized_tables = {}
            for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
                h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
                k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
                # Fetch keys and values in one run so they come from the same export.
                initialized_tables[n] = tuple(sess.run([k, v]))

        # NOTE(review): nesting reconstructed; the optimize pass is assumed to belong
        # to the session branch only - confirm.
        tf_reset_default_graph()
        with tf_session() as sess:
            tf.import_graph_def(graph_def, name='')
            graph_def = tf_optimize(list(feed_dict.keys()), output_names_with_port,
                                    graph_def, fold_constant=constant_fold)

    tf_reset_default_graph()
    with tf_session() as sess:
        const_node_values = None
        if large_model:
            const_node_values = compress_graph_def(graph_def)
        tf.import_graph_def(graph_def, name='')

        if self.config.is_debug_mode:
            model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
            utils.save_protobuf(model_path, graph_def)
            self.logger.debug("created file %s", model_path)

        g = process_tf_graph(sess.graph, opset=self.config.opset,
                             input_names=list(feed_dict.keys()),
                             output_names=output_names_with_port,
                             target=self.config.target,
                             const_node_values=const_node_values,
                             initialized_tables=initialized_tables,
                             **process_args)
        g = optimizer.optimize_graph(g, catch_errors=False)
        actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)

    for expected_val, actual_val in zip(expected, actual):
        if check_value:
            # The np.object alias no longer exists in current NumPy; the builtin
            # object compares equal to the same dtype.
            if expected_val.dtype == object:
                # Object arrays hold bytes; decode them before comparing.
                decode = np.vectorize(lambda x: x.decode('UTF-8'))
                expected_val_str = decode(expected_val)
                self.assertAllEqual(expected_val_str, actual_val)
            else:
                self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=atol)
        if check_dtype:
            self.assertEqual(expected_val.dtype, actual_val.dtype)
        # why the shape check is needed: issue when comparing [] with a scalar
        # https://github.com/numpy/numpy/issues/11071
        if check_shape:
            self.assertEqual(expected_val.shape, actual_val.shape)

    if graph_validator:
        self.assertTrue(graph_validator(g))

    return g
def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def,
                         concrete_function_index, large_model):
    """Load tensorflow graph from saved_model.

    Args:
        model_path: path passed to tf.saved_model.load.
        input_names: optional names used to filter the concrete function's inputs.
        output_names: optional names used to filter the concrete function's outputs.
        tag: saved_model tags; None falls back to ['serve'].
        signature_def: sequence with at most one signature name; None or empty selects
            the first usable signature unless concrete_function_index is given.
        concrete_function_index: index into imported.__call__.concrete_functions, or None.
        large_model: forwarded to from_function.

    Returns:
        Tuple (frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables).
        inputs and outputs keep the order of the concrete function's tensors.

    Raises:
        ValueError: from from_function; a protobuf size error is replaced by a message
            suggesting --large_model.
    """
    wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
    wrn_empty_tag = "'--tag' value is empty string. Using tag =[[]]"
    wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
    err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
    err_no_call = "Model doesn't contain usable concrete functions under __call__. Try --signature-def instead."
    err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
    err_no_sig = "No signatures found in model. Try --concrete_function instead."
    err_sig_nomatch = "Specified signature not in model %s"
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."

    if tag is None:
        tag = ['serve']
        logger.warning(wrn_no_tag)

    if tag == '':
        tag = [[]]
        logger.warning(wrn_empty_tag)

    # The branches below already treat a missing signature_def as "not specified";
    # normalise None so the length check does not fail first.
    signature_def = signature_def or []
    utils.make_sure(len(signature_def) < 2, err_many_sig, str(signature_def))
    imported = tf.saved_model.load(model_path, tags=tag)  # pylint: disable=no-value-for-parameter

    all_sigs = imported.signatures.keys()
    valid_sigs = [s for s in all_sigs if not s.startswith("_")]
    logger.info("Signatures found in model: %s", "[" + ",".join(valid_sigs) + "].")

    concrete_func = None
    if concrete_function_index is not None:
        utils.make_sure(hasattr(imported, "__call__"), err_no_call)
        utils.make_sure(concrete_function_index < len(imported.__call__.concrete_functions),
                        err_index, concrete_function_index, len(imported.__call__.concrete_functions) - 1)
        # structured_input_signature is an (args, kwargs) pair; pass both parts so
        # keyword inputs are not dropped.
        args, kwargs = imported.__call__.concrete_functions[concrete_function_index].structured_input_signature
        concrete_func = imported.__call__.get_concrete_function(*args, **kwargs)
    elif signature_def:
        utils.make_sure(signature_def[0] in valid_sigs, err_sig_nomatch, signature_def[0])
        concrete_func = imported.signatures[signature_def[0]]
    else:
        utils.make_sure(len(valid_sigs) > 0, err_no_sig)
        logger.warning(wrn_sig_1, valid_sigs[0])
        concrete_func = imported.signatures[valid_sigs[0]]

    inputs = [tensor.name for tensor in concrete_func.inputs if tensor.dtype != tf.dtypes.resource]
    outputs = [tensor.name for tensor in concrete_func.outputs if tensor.dtype != tf.dtypes.resource]

    # filter by user specified inputs/outputs
    # Filter in place of a set intersection so the resulting order is deterministic.
    if input_names:
        wanted_inputs = set(input_names)
        inputs = [name for name in inputs if name in wanted_inputs]
    if output_names:
        wanted_outputs = set(output_names)
        outputs = [name for name in outputs if name in wanted_outputs]

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise

    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    initialized_tables = {}
    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            # Best effort: a table that cannot be exported is reported and skipped.
            logger.warning("Could not initialize table with shared_name = %r", n)

    return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables
def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def,
                         concrete_function_index, large_model):
    """Load tensorflow graph from saved_model.

    Args:
        model_path: path passed to tf.saved_model.load.
        input_names: optional names used to filter the concrete function's inputs.
        output_names: optional names used to filter the concrete function's outputs.
        tag: saved_model tags; None falls back to ['serve'].
        signature_def: sequence with at most one signature name; None or empty selects
            the first usable signature unless concrete_function_index is given.
        concrete_function_index: index into imported.__call__.concrete_functions, or None.
        large_model: forwarded to from_function.

    Returns:
        Tuple (frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables).
        inputs and outputs keep the order of the concrete function's tensors.

    Raises:
        ValueError: from from_function; a protobuf size error is replaced by a message
            suggesting --large_model.
    """
    wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
    wrn_empty_tag = "'--tag' value is empty string. Using tag =[[]]"
    wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
    err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
    err_no_call = "Model doesn't contain usable concrete functions under __call__. Try --signature-def instead."
    err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
    err_no_sig = "No signatures found in model. Try --concrete_function instead."
    err_sig_nomatch = "Specified signature not in model %s"
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."

    if tag is None:
        tag = ['serve']
        logger.warning(wrn_no_tag)

    if tag == '':
        tag = [[]]
        logger.warning(wrn_empty_tag)

    # The branches below already treat a missing signature_def as "not specified";
    # normalise None so the length check does not fail first.
    signature_def = signature_def or []
    utils.make_sure(len(signature_def) < 2, err_many_sig, str(signature_def))
    imported = tf.saved_model.load(model_path, tags=tag)  # pylint: disable=no-value-for-parameter

    all_sigs = imported.signatures.keys()
    valid_sigs = [s for s in all_sigs if not s.startswith("_")]
    logger.info("Signatures found in model: %s", "[" + ",".join(valid_sigs) + "].")

    concrete_func = None
    if concrete_function_index is not None:
        utils.make_sure(hasattr(imported, "__call__"), err_no_call)
        utils.make_sure(concrete_function_index < len(imported.__call__.concrete_functions),
                        err_index, concrete_function_index, len(imported.__call__.concrete_functions) - 1)
        # structured_input_signature is an (args, kwargs) pair; pass both parts so
        # keyword inputs are not dropped.
        args, kwargs = imported.__call__.concrete_functions[concrete_function_index].structured_input_signature
        concrete_func = imported.__call__.get_concrete_function(*args, **kwargs)
    elif signature_def:
        utils.make_sure(signature_def[0] in valid_sigs, err_sig_nomatch, signature_def[0])
        concrete_func = imported.signatures[signature_def[0]]
    else:
        utils.make_sure(len(valid_sigs) > 0, err_no_sig)
        logger.warning(wrn_sig_1, valid_sigs[0])
        concrete_func = imported.signatures[valid_sigs[0]]

    inputs = [tensor.name for tensor in concrete_func.inputs if tensor.dtype != tf.dtypes.resource]
    outputs = [tensor.name for tensor in concrete_func.outputs if tensor.dtype != tf.dtypes.resource]

    # filter by user specified inputs/outputs
    # Filter in place of a set intersection so the resulting order is deterministic.
    if input_names:
        wanted_inputs = set(input_names)
        inputs = [name for name in inputs if name in wanted_inputs]
    if output_names:
        wanted_outputs = set(output_names)
        outputs = [name for name in outputs if name in wanted_outputs]

    # Avoid errors due to bug in TF freezing
    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
        _remove_non_variable_resources_from_captures(concrete_func)

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise
    finally:
        # concrete_func belongs to the loaded model, so put it back in working order
        # whether or not freezing succeeded.
        _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)

    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    placeholder_to_table_info = {}
    if hasattr(imported, '_table') and hasattr(imported._table, '_create_resource'):  # pylint: disable=protected-access
        # Add tables from saved_model table initializers
        # pylint: disable=protected-access
        initializer = imported._table._create_resource.concrete_functions[0].function_def
        new_names, new_k_dtypes, new_v_dtypes = get_hash_table_info(initializer.node_def)
        table_names.extend(new_names)
        key_dtypes.extend(new_k_dtypes)
        value_dtypes.extend(new_v_dtypes)
        table_handle = id(imported._table.resource_handle)
        # Only map the placeholder when the initializer names exactly one table.
        if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
            table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
            placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info

    initialized_tables = {}
    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            # Best effort: a table that cannot be exported is reported and skipped.
            logger.warning("Could not initialize table with shared_name = %r", n)

    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)

    return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables
def freeze_and_run_tf(self, func, feed_dict, outputs, as_session, premade_placeholders, large_model):
    """Run func in TensorFlow and return its results together with the frozen graph.

    Args:
        func: callable building (graph mode) or computing (eager mode) the model.
        feed_dict: maps input tensor names to numpy values.
        outputs: output tensor names.
        as_session: when true, force the graph/session path even on TF2.
        premade_placeholders: when true, func is called without arguments in graph mode
            and no placeholders are created here.
        large_model: forwarded to from_function on the eager path.

    Returns:
        Tuple (result, graph_def, initialized_tables). initialized_tables is None on the
        eager path and a dict of shared_name -> (keys, values) on the session path.
    """
    np.random.seed(1)  # Make it reproducible.
    clean_feed_dict = {utils.node_name(k): v for k, v in feed_dict.items()}
    if is_tf2() and not as_session:
        #
        # use eager to execute the tensorflow func
        #
        # numpy doesn't work for all ops, make it tf.Tensor()
        input_tensors = [tf.TensorSpec(shape=v.shape, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                         for k, v in feed_dict.items()]
        input_list = [tf.convert_to_tensor(v, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                      for k, v in feed_dict.items()]
        tf.random.set_seed(1)
        result = func(*input_list)
        if isinstance(result, (list, tuple)):
            # list or tuple
            result = [x.numpy() for x in result]
        else:
            # single result
            result = [result.numpy()]

        # now make the eager functions a graph
        concrete_func = tf.function(func, input_signature=tuple(input_tensors))
        concrete_func = concrete_func.get_concrete_function()
        graph_def = from_function(concrete_func,
                                  input_names=list(feed_dict.keys()),
                                  output_names=outputs,
                                  large_model=large_model)
        initialized_tables = None
    else:
        #
        # use graph to execute the tensorflow func
        #
        with tf_session() as sess:
            tf_set_random_seed(1)
            input_list = []
            if not premade_placeholders:
                for k, v in clean_feed_dict.items():
                    input_list.append(tf_placeholder(name=k, shape=v.shape, dtype=tf.as_dtype(v.dtype)))
            func(*input_list)
            variables_lib.global_variables_initializer().run()
            tf_tables_initializer().run()

            output_dict = []
            for out_name in outputs:
                output_dict.append(sess.graph.get_tensor_by_name(out_name))
            result = sess.run(output_dict, feed_dict=feed_dict)
            graph_def = freeze_session(sess,
                                       input_names=list(feed_dict.keys()),
                                       output_names=outputs)
            table_info = get_hash_table_info(graph_def)
            initialized_tables = {}
            for info in table_info:
                if info.shared_name is None:
                    # Without a shared_name there is no key to store the table under.
                    continue
                h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=info.shared_name)
                k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
                # Fetch keys and values in one run so they come from the same export.
                initialized_tables[info.shared_name] = tuple(sess.run([k, v]))

        # NOTE(review): nesting reconstructed from a whitespace-collapsed source; the
        # optimize pass is assumed to belong to the session branch only - confirm.
        tf_reset_default_graph()
        with tf_session() as sess:
            tf.import_graph_def(graph_def, name='')
            graph_def = tf_optimize(list(feed_dict.keys()), outputs, graph_def)

    return result, graph_def, initialized_tables