def Inference(self):
  with tf.name_scope('inference'):
    feed1 = tf.placeholder(name='feed1_node', dtype=tf.float32, shape=[1])
    fetch1 = tf.identity(feed1, name='fetch1_node')
    inference_graph = inference_graph_pb2.InferenceGraph()
    subgraph = inference_graph.subgraphs['default']
    subgraph.feeds['feed1'] = feed1.name
    subgraph.fetches['fetch1'] = fetch1.name
    # Tests that ops are supported.
    subgraph.fetches['fetch_op'] = fetch1.op.name
    return inference_graph
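# Hedged usage sketch (assumption, not part of the source): how a caller might
# drive one of the exported subgraphs through a TF1 session. RunSubgraph is a
# hypothetical helper; the feed/fetch tensor names are looked up from the
# InferenceGraph proto rather than hard-coded.
def RunSubgraph(sess, inference_graph, subgraph_name, inputs):
  """Runs `subgraph_name`, feeding `inputs` keyed by feed alias.

  Returns a dict of fetch alias -> fetched value.
  """
  subgraph = inference_graph.subgraphs[subgraph_name]
  # Map feed alias -> tensor name, then fetch by tensor/op name.
  feed_dict = {subgraph.feeds[alias]: value for alias, value in inputs.items()}
  return sess.run(dict(subgraph.fetches), feed_dict=feed_dict)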
def LoadInferenceGraph(path):
  """Parse the given path as an InferenceGraph proto.

  Args:
    path: The path to the file to load.

  Returns:
    An InferenceGraph object.
  """
  inference_graph = inference_graph_pb2.InferenceGraph()
  with tf.gfile.Open(path, "r") as f:
    text_format.Parse(f.read(), inference_graph)
  return inference_graph
def Inference(self):
  with tf.name_scope('inference'):
    feed1 = tf.placeholder(name='feed1_node', dtype=tf.float32, shape=[1])
    fetch1 = tf.identity(feed1, name='fetch1_node')
    feed2 = tf.placeholder(name='feed2_node', dtype=tf.float32, shape=[2])
    fetch2 = tf.identity(feed2, name='fetch2_node')
    inference_graph = inference_graph_pb2.InferenceGraph()
    subgraph = inference_graph.subgraphs['default']
    subgraph.feeds['feed1'] = feed1.name
    subgraph.fetches['fetch1'] = fetch1.name
    subgraph = inference_graph.subgraphs['subgraph2']
    subgraph.feeds['feed1'] = feed2.name
    subgraph.fetches['fetch1'] = fetch2.name
    return inference_graph
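# Hedged follow-up sketch (assumption, not part of the source): with the two
# subgraphs defined above, the hypothetical RunSubgraph helper from the earlier
# sketch could drive each one independently, since both expose the aliases
# 'feed1'/'fetch1' but bind them to different tensors.
#
# out_default = RunSubgraph(sess, inference_graph, 'default',
#                           {'feed1': [1.0]})
# out_second = RunSubgraph(sess, inference_graph, 'subgraph2',
#                          {'feed1': [1.0, 2.0]})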
def LoadInferenceGraph(path, clear_device_placement=False):
  """Parse the given path as an InferenceGraph proto.

  Args:
    path: The path to the file to load.
    clear_device_placement: If true, clears device field from nodes in graph.

  Returns:
    An InferenceGraph object.
  """
  inference_graph = inference_graph_pb2.InferenceGraph()
  with tf.io.gfile.GFile(path, "r") as f:
    text_format.Parse(f.read(), inference_graph)
  if clear_device_placement:
    for node in inference_graph.graph_def.node:
      node.ClearField("device")
    for function in inference_graph.graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField("device")
  return inference_graph
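# Hedged usage sketch (assumption, not part of the source): reloading an
# exported text-format InferenceGraph and stripping device placement so the
# serving runtime can re-place ops. The path below is only illustrative.
def _ReloadForServing(path='/tmp/inference_graph.pbtxt'):
  inference_graph = LoadInferenceGraph(path, clear_device_placement=True)
  return inference_graph.subgraphs['default']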
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None):
  """Exports an InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default-initializes the graph and freezes it. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A list of subgraph names. If not None or empty, export
      only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs.
  """
  assert issubclass(model_cfg.cls, base_model.BaseModel)

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Disable packed inputs for inference writing purposes.
  def _DisablePackedInput(task):
    if (_ParamExists(task, 'encoder') and
        _ParamExists(task.encoder, 'packed_input')):
      task.encoder.packed_input = False
    if (_ParamExists(task, 'decoder') and
        _ParamExists(task.decoder, 'packed_input')):
      task.decoder.packed_input = False

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_eval = True
  model_cfg.is_inference = True

  if issubclass(model_cfg.cls, base_model.MultiTaskModel):
    for _, task_param in model_cfg.task_params.IterParams():
      _DisablePackedInput(task_param)
  else:
    _DisablePackedInput(model_cfg.task)

  tf.logging.info('Model %s. Params: %s', model_cfg.name, model_cfg.ToText())

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.set_random_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)
      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      try:
        mdl = model_cfg.Instantiate()
        variables_to_restore = (
            _MakeVariableDictionary(tf.global_variables())
            if not mdl.ema else mdl.ema.variables_to_restore())

        if bfloat16_override:
          saver_var_spec = (
              bfloat16_variables
              .get_saver_spec_for_variables_with_bf16_overrides(
                  variables_to_restore))
        else:
          saver_var_spec = variables_to_restore

        saver = tf.train.Saver(saver_var_spec)
        tf.variables_initializer(
            tf.global_variables(), name='init_all_variables')
        if IsTpu(device_options) and device_options.gen_init_op:
          tf.group(tf.compat.v1.tpu.initialize_system(), name='tpu_init_op')

        model_task = mdl.GetTask(model_task_name)

        inference_graph_proto = inference_graph_pb2.InferenceGraph()
        subgraphs_proto = model_task.Inference()
        if isinstance(subgraphs_proto, dict):
          subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
        for name, subgraph in subgraphs_proto.subgraphs.items():
          if not subgraph_filter or name in subgraph_filter:
            inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

        # Add a table init op and global variable init op to the graph.
        # Tables can be declared anywhere in the graph, so this op has to be
        # added last.
        tf.tables_initializer(name='init_all_tables')
      finally:
        # Reset TPU-related flags after model instantiation.
        FLAGS.enable_asserts = old_enable_asserts
        FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph, inference_graph_proto, preserve_colocation_nodes=False)
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    output_op_names = GetOutputOpNames(graph, inference_graph_proto)

    # Prune the graph to just the parts we need.
    # To support restoring, we have to not prune out the restore node.
    output_op_names.append('init_all_tables')
    output_op_names.append('init_all_variables')
    output_op_names.append('save/control_dependency')
    output_op_names.append('save/restore_all')
    if IsTpu(device_options) and device_options.gen_init_op:
      output_op_names.append('tpu_init_op')
    graph_def = graph.as_graph_def()
    tf.logging.info('Pruning graph to output ops: %r', output_op_names)
    graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.gfile.Open(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))

  return inference_graph_proto
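# Hedged usage sketch (assumption, not part of the source): how Export might be
# called when it is exposed as a classmethod. The containing class name
# InferenceGraphExporter and the model name 'test.MyModel' are assumptions for
# illustration only; model_registry.GetParams is the lookup named in the
# docstring above.
def _ExportFrozenDefaultGraph(export_path):
  model_cfg = model_registry.GetParams('test.MyModel', 'Test')
  return InferenceGraphExporter.Export(
      model_cfg,
      freeze_defaults=True,
      subgraph_filter=['default'],
      export_path=export_path)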
def Inference(self):
  return inference_graph_pb2.InferenceGraph()
def LoadInferenceGraph(path):
  inference_graph = inference_graph_pb2.InferenceGraph()
  with tf.gfile.Open(path, "r") as f:
    text_format.Parse(f.read(), inference_graph)
  return inference_graph
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None,
               fprop_dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True,
           prune_graph=True,
           export_graph_collections=False):
  """Exports an InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

  Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
  and multi-core inference on TPUs work properly.

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default-initializes the graph and freezes it. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A string or a list of subgraph names. If not None or
      empty, export only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing purposes.
    prune_graph: If true, prune the graph to just the parts we need.
    export_graph_collections: If true, export graph collections to the
      InferenceGraph proto.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs.
  """
  if py_utils.IsEagerMode():
    raise ValueError('InferenceGraph exporter does not work in Eager mode.')
  assert issubclass(model_cfg.cls, base_model.BaseModel)
  if device_options.dtype_override and device_options.fprop_dtype_override:
    raise ValueError(
        'device_options.{dtype_override, fprop_dtype_override} cannot both '
        'be set.')
  if subgraph_filter and not isinstance(subgraph_filter, (tuple, list)):
    subgraph_filter = [subgraph_filter]

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.debug('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.debug('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)
      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      act_bfloat16_override = ShouldForceBfloat16ForActivations(
          device_options)
      if act_bfloat16_override:
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      try:
        mdl = model_cfg.Instantiate()
        task = mdl.GetTask(model_task_name)
        variables_to_restore = (
            _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
            mdl.ema.variables_to_restore(mdl.variables_for_ema))

        if bfloat16_override:
          saver_var_spec = (
              bfloat16_variables
              .get_saver_spec_for_variables_with_bf16_overrides(
                  variables_to_restore))
          # For TPU embedding layers, if the table explicitly specifies the
          # inference dtype as bfloat16, the variables in the checkpoint must
          # already be in bfloat16, so we change back to bfloat16 to avoid
          # dtype mismatch.
          for var_name in (tpu_embedding_layers.TpuEmbeddingCollection.Get()
                           .inference_with_bfloat16_var_names):
            saver_var_spec[var_name] = variables_to_restore[var_name]
        else:
          saver_var_spec = variables_to_restore

        saver = tf.train.Saver(saver_var_spec)
        tf.variables_initializer(
            tf.global_variables(), name='init_all_variables')
        if IsTpu(device_options) and device_options.gen_init_op:
          tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

        if freeze_checkpoint or freeze_defaults:
          # Replace variables with tensors using tf.identity in theta before
          # freezing to avoid the graph referencing types of DT_RESOURCE.
          def AddIdentityToTheta(layer):
            # pylint: disable=protected-access
            layer._private_theta = py_utils.Transform(tf.identity,
                                                      layer._private_theta)
            # pylint: enable=protected-access
            layer.children.Transform(AddIdentityToTheta)

          AddIdentityToTheta(task)

        inference_graph_proto = inference_graph_pb2.InferenceGraph()
        subgraphs_proto = task.Inference()
        if isinstance(subgraphs_proto, dict):
          subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
        for name, subgraph in subgraphs_proto.subgraphs.items():
          if not subgraph_filter or name in subgraph_filter:
            inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

        if not inference_graph_proto.subgraphs and subgraph_filter:
          raise ValueError(
              f'Subgraph filters {subgraph_filter} filtered out all '
              'subgraphs. Defined subgraphs: '
              f'{list(subgraphs_proto.subgraphs.keys())}')

        # Yes, graph collections are bad, however this seems to be the
        # easiest way to get these assets registered from
        # TextFileInitializer.
        assets_collection = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
        for asset in assets_collection:
          if asset.op.type == 'Const' and asset.op.get_attr(
              'dtype') == tf.dtypes.string:
            constant_value = asset.op.get_attr('value')
            if constant_value.string_val:
              tf.logging.info('Found asset file_path: %s',
                              constant_value.string_val[0])
              asset_file_def = inference_graph_proto.asset_file_def.add()
              asset_file_def.tensor_info.name = asset.name
              asset_file_def.filename = constant_value.string_val[0]

        # Add a table init op and global variable init op to the graph.
        # Tables can be declared anywhere in the graph, so this op has to be
        # added last.
        tf.tables_initializer(name='init_all_tables')
      finally:
        # Reset TPU-related flags after model instantiation.
        FLAGS.enable_asserts = old_enable_asserts
        FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  # Collection defs
  if not tf.executing_eagerly():
    if export_graph_collections:
      meta_graph = tf.train.export_meta_graph(graph=graph)
      for key in meta_graph.collection_def:
        tf.logging.info('copying collection %s', key)
        inference_graph_proto.collection_def[key].CopyFrom(
            meta_graph.collection_def[key])
  else:
    tf.logging.warning('Not exporting collection defs '
                       'since operating in eager mode.')

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph,
        inference_graph_proto,
        preserve_colocation_nodes=False,
        preserve_saver_restore_nodes=False)
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())
    graph_def = graph.as_graph_def()

    if prune_graph:
      output_op_names = GetOutputOpNames(graph, inference_graph_proto)

      # Prune the graph to just the parts we need.
      # To support restoring, we have to not prune out the restore node.
      output_op_names.append('init_all_tables')
      output_op_names.append('init_all_variables')
      output_op_names.append('save/control_dependency')
      output_op_names.append('save/restore_all')
      if IsTpu(device_options) and device_options.gen_init_op:
        output_op_names.append('tpu_init_op')
      tf.logging.info('Pruning graph to output ops: %r', output_op_names)
      graph_def = tf.compat.v1.graph_util.extract_sub_graph(
          graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))

  return inference_graph_proto
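# Hedged end-to-end sketch (assumption, not part of the source): export with
# the newer signature, then reload the written text proto with
# LoadInferenceGraph defined earlier. The containing class name
# InferenceGraphExporter, the model name, and the path are assumptions.
def _ExportAndReload(export_path='/tmp/inference_graph.pbtxt'):
  model_cfg = model_registry.GetParams('test.MyModel', 'Test')
  InferenceGraphExporter.Export(
      model_cfg,
      subgraph_filter='default',  # a bare string is wrapped into a list
      prune_graph=True,
      export_path=export_path)
  return LoadInferenceGraph(export_path, clear_device_placement=True)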