コード例 #1
0
 def Inference(self):
     """Builds an InferenceGraph with a single 'default' subgraph.

     The subgraph feeds one float32 placeholder of shape [1] and fetches an
     identity of it, both created under the 'inference' name scope.

     Returns:
       An inference_graph_pb2.InferenceGraph proto.
     """
     with tf.name_scope('inference'):
         source = tf.placeholder(name='feed1_node', dtype=tf.float32, shape=[1])
         sink = tf.identity(source, name='fetch1_node')
         graph_proto = inference_graph_pb2.InferenceGraph()
         default_sg = graph_proto.subgraphs['default']
         default_sg.feeds['feed1'] = source.name
         # 'fetch_op' is registered with an op name rather than a tensor name,
         # to exercise support for fetching ops.
         fetch_entries = (
             ('fetch1', sink.name),
             ('fetch_op', sink.op.name),
         )
         for fetch_key, fetch_target in fetch_entries:
             default_sg.fetches[fetch_key] = fetch_target
     return graph_proto
コード例 #2
0
ファイル: predictor.py プロジェクト: zy1620454507/lingvo
def LoadInferenceGraph(path):
  """Reads a text-format InferenceGraph proto from disk.

  Args:
    path: Location of the text-format proto file to read.

  Returns:
    The parsed inference_graph_pb2.InferenceGraph message.
  """
  with tf.gfile.Open(path, "r") as handle:
    serialized = handle.read()
  graph = inference_graph_pb2.InferenceGraph()
  text_format.Parse(serialized, graph)
  return graph
コード例 #3
0
 def Inference(self):
     """Builds an InferenceGraph with two independent identity subgraphs.

     'default' maps feed1 -> fetch1 over a float32 placeholder of shape [1].
     'subgraph2' reuses the same feed/fetch keys over a second placeholder of
     shape [2].

     Returns:
       An inference_graph_pb2.InferenceGraph proto.
     """
     with tf.name_scope('inference'):
         in_a = tf.placeholder(name='feed1_node', dtype=tf.float32, shape=[1])
         out_a = tf.identity(in_a, name='fetch1_node')
         in_b = tf.placeholder(name='feed2_node', dtype=tf.float32, shape=[2])
         out_b = tf.identity(in_b, name='fetch2_node')
         result = inference_graph_pb2.InferenceGraph()
         wiring = (
             ('default', in_a, out_a),
             ('subgraph2', in_b, out_b),
         )
         for subgraph_name, feed_tensor, fetch_tensor in wiring:
             target = result.subgraphs[subgraph_name]
             target.feeds['feed1'] = feed_tensor.name
             target.fetches['fetch1'] = fetch_tensor.name
         return result
コード例 #4
0
def LoadInferenceGraph(path, clear_device_placement=False):
  """Reads a text-format InferenceGraph proto, optionally stripping devices.

  Args:
    path: Location of the text-format InferenceGraph file.
    clear_device_placement: When true, the `device` field is cleared on every
      node of `graph_def` and on every node of each function in
      `graph_def.library`.

  Returns:
    The parsed inference_graph_pb2.InferenceGraph message.
  """
  graph_proto = inference_graph_pb2.InferenceGraph()
  with tf.io.gfile.GFile(path, "r") as reader:
    text_format.Parse(reader.read(), graph_proto)
  if not clear_device_placement:
    return graph_proto

  graph_def = graph_proto.graph_def
  # Top-level nodes first, then the nodes inside each library function.
  device_bearing_nodes = list(graph_def.node)
  for fn in graph_def.library.function:
    device_bearing_nodes.extend(fn.node_def)
  for node in device_bearing_nodes:
    node.ClearField("device")
  return graph_proto
コード例 #5
0
    def Export(cls,
               model_cfg,
               model_task_name=None,
               device_options=InferenceDeviceOptions(
                   device='',
                   retain_device_placement=False,
                   var_options=None,
                   gen_init_op=True,
                   dtype_override=None),
               freeze_checkpoint=None,
               freeze_defaults=False,
               export_path=None,
               subgraph_filter=None,
               random_seed=None):
        """Exports a InferenceGraph proto with piecewise subgraphs.

        Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

        Args:
          model_cfg: a Params instance as returned by
            model_registry.GetParams(modelname, 'Test') or model_params.Model().
          model_task_name: The task to generate an inference graph for. Should be
            None for single-task models.
          device_options: Device options for the accelerator used for serving.
          freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
            given.
          freeze_defaults: Default initializes the graph and freeze. Useful for
            early testing of downstream tools without having a checkpoint.
          export_path: If not None, write the inference graph in ASCII to this path.
          subgraph_filter: A string or a list of subgraph names. If not None or
            empty, export only this list of inference subgraphs.
          random_seed: Fixes the random seed in the exported inference graph.

        Returns:
          InferenceGraph proto.

        Raises:
          ValueError: if subgraph_filter is non-empty but matches none of the
            subgraphs defined by the model's Inference(), or if freezing is
            requested and cls._DeviceSupportsFreezing(device_options) is true.
        """
        assert issubclass(model_cfg.cls, base_model.BaseModel)
        # `name in subgraph_filter` below would do substring matching on a bare
        # string (e.g. 'default' would also select a subgraph named 'def'), so
        # treat a string as a single subgraph name.
        if subgraph_filter and isinstance(subgraph_filter, str):
            subgraph_filter = [subgraph_filter]

        # Disable assertions unless user explicitly enables it.
        if FLAGS['enable_asserts'].using_default_value:
            FLAGS.enable_asserts = False

        # TODO(laurenzo): Work out how much we need to specify here in terms of
        # cluster configuration.
        cls._SetClusterParams(model_cfg.cluster, device_options)

        # Disable packed inputs for inference writing purposes.
        def _DisablePackedInput(task):
            if (_ParamExists(task, 'encoder')
                    and _ParamExists(task.encoder, 'packed_input')):
                task.encoder.packed_input = False
            if (_ParamExists(task, 'decoder')
                    and _ParamExists(task.decoder, 'packed_input')):
                task.decoder.packed_input = False

        # Configure the model.
        model_cfg.random_seed = random_seed
        model_cfg.is_eval = True
        model_cfg.is_inference = True

        if issubclass(model_cfg.cls, base_model.MultiTaskModel):
            for _, task_param in model_cfg.task_params.IterParams():
                _DisablePackedInput(task_param)
        else:
            _DisablePackedInput(model_cfg.task)

        tf.logging.info('Model %s. Params: %s', model_cfg.name,
                        model_cfg.ToText())

        # Instantiate the graph.
        graph = tf.Graph()
        with graph.as_default():
            tf.set_random_seed(random_seed)
            cluster = model_cfg.cluster.Instantiate()
            device = cluster.GetPlacer()
            tpu_const_scope = _DummyScope()
            if (IsTpu(device_options)
                    and device_options.var_options == 'AS_CONSTANTS'):
                # Do not specify devices for variables if we are marking them as
                # constants.
                device = ''
                tpu_const_scope = ConstGuaranteeScope()

            with cluster, tf.device(device), tpu_const_scope:

                bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
                    device_options)

                if bfloat16_override:
                    py_utils.UpdateDtype(model_cfg, tf.bfloat16)
                    py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

                # Hard-code TPU-related flags prior to instantiating model.
                old_enable_asserts = FLAGS.enable_asserts
                old_xla_device = FLAGS.xla_device
                if IsTpu(device_options):
                    FLAGS.enable_asserts = False
                    FLAGS.xla_device = 'tpu'

                try:
                    mdl = model_cfg.Instantiate()
                    if mdl.ema:
                        variables_to_restore = mdl.ema.variables_to_restore()
                    else:
                        variables_to_restore = _MakeVariableDictionary(
                            tf.global_variables())

                    if bfloat16_override:
                        saver_var_spec = (
                            bfloat16_variables.
                            get_saver_spec_for_variables_with_bf16_overrides(
                                variables_to_restore))
                    else:
                        saver_var_spec = variables_to_restore

                    saver = tf.train.Saver(saver_var_spec)
                    tf.variables_initializer(tf.global_variables(),
                                             name='init_all_variables')
                    if IsTpu(device_options) and device_options.gen_init_op:
                        tf.group(tf.compat.v1.tpu.initialize_system(),
                                 name='tpu_init_op')

                    model_task = mdl.GetTask(model_task_name)

                    inference_graph_proto = inference_graph_pb2.InferenceGraph()
                    subgraphs_proto = model_task.Inference()
                    if isinstance(subgraphs_proto, dict):
                        subgraphs_proto = ConvertSubgraphDictToProto(
                            subgraphs_proto)
                    for name, subgraph in subgraphs_proto.subgraphs.items():
                        if not subgraph_filter or name in subgraph_filter:
                            inference_graph_proto.subgraphs[name].CopyFrom(
                                subgraph)

                    # A filter that selects nothing would otherwise silently
                    # export a graph with no subgraphs.
                    if subgraph_filter and not inference_graph_proto.subgraphs:
                        raise ValueError(
                            'Subgraph filters {} filtered out all subgraphs. '
                            'Defined subgraphs: {}'.format(
                                subgraph_filter,
                                list(subgraphs_proto.subgraphs.keys())))

                    # Add a table init op and global variable init op to the graph.
                    # Tables can be declared anywhere in the graph, so this op has to be
                    # added last.
                    tf.tables_initializer(name='init_all_tables')
                finally:
                    # Reset TPU-related flags after model instantiation.
                    FLAGS.enable_asserts = old_enable_asserts
                    FLAGS.xla_device = old_xla_device

        tf.logging.info('Graph contains ops: %r',
                        [op.name for op in graph.get_operations()])

        inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

        # Freezing.
        if freeze_defaults or freeze_checkpoint:
            output_op_names = GetOutputOpNames(graph,
                                               inference_graph_proto,
                                               preserve_colocation_nodes=False)
            # NOTE(review): despite its name, a true result from
            # _DeviceSupportsFreezing rejects freezing here. Confirm the
            # helper's semantics before changing this condition.
            if cls._DeviceSupportsFreezing(device_options):
                raise ValueError(
                    'freeze_checkpoint cannot be used with device ' +
                    device_options.device)
            if freeze_checkpoint:
                tf.logging.info('Freezing graph from checkpoint: %s',
                                freeze_checkpoint)
                graph_def = _FreezeGraphFromCheckpoint(graph, saver,
                                                       freeze_checkpoint,
                                                       output_op_names)
            elif freeze_defaults:
                tf.logging.info('Default initializing graph and freezing.')
                graph_def = _FreezeDefaults(graph, output_op_names)
        else:
            output_op_names = GetOutputOpNames(graph, inference_graph_proto)

            # Prune the graph to just the parts we need.
            # To support restoring, we have to not prune out the restore node.
            output_op_names.append('init_all_tables')
            output_op_names.append('init_all_variables')
            output_op_names.append('save/control_dependency')
            output_op_names.append('save/restore_all')
            if IsTpu(device_options) and device_options.gen_init_op:
                output_op_names.append('tpu_init_op')
            graph_def = graph.as_graph_def()
            tf.logging.info('Pruning graph to output ops: %r', output_op_names)
            graph_def = tf.graph_util.extract_sub_graph(
                graph_def, output_op_names)

        if not device_options.retain_device_placement:
            # Clear the device so that the runtime can choose.
            tf.logging.info('Clearing device placement for: %s',
                            device_options.device)
            for node in graph_def.node:
                node.ClearField('device')
            for function in graph_def.library.function:
                for node_def in function.node_def:
                    node_def.ClearField('device')

        inference_graph_proto.graph_def.CopyFrom(graph_def)

        if export_path:
            with tf.gfile.Open(export_path, 'w') as f:
                f.write(text_format.MessageToString(inference_graph_proto))
        return inference_graph_proto
コード例 #6
0
 def Inference(self):
     """Returns a freshly constructed InferenceGraph that declares no subgraphs."""
     empty_graph = inference_graph_pb2.InferenceGraph()
     return empty_graph
コード例 #7
0
ファイル: predictor.py プロジェクト: gottaMe/self-dsrn
def LoadInferenceGraph(path):
  """Parses the text-format InferenceGraph proto stored at `path`.

  Args:
    path: Location of the text-format proto file.

  Returns:
    The parsed inference_graph_pb2.InferenceGraph message.
  """
  with tf.gfile.Open(path, "r") as source:
    contents = source.read()
  # text_format.Parse fills in and returns the message it is given.
  return text_format.Parse(contents, inference_graph_pb2.InferenceGraph())
コード例 #8
0
  def Export(cls,
             model_cfg,
             model_task_name=None,
             device_options=InferenceDeviceOptions(
                 device='',
                 retain_device_placement=False,
                 var_options=None,
                 gen_init_op=True,
                 dtype_override=None,
                 fprop_dtype_override=None),
             freeze_checkpoint=None,
             freeze_defaults=False,
             export_path=None,
             subgraph_filter=None,
             random_seed=None,
             disable_packed_input=True,
             prune_graph=True,
             export_graph_collections=False):
    """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

    Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
    and multi-core inference on TPUs work properly.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default initializes the graph and freeze. Useful for
        early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A string, or a list/tuple/set of subgraph names. If not
        None or empty, export only these inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input for inference writing purposes.
      prune_graph: If true, prune the graph to just the parts we need.
      export_graph_collections: If true, export graph collections to the
        InferenceGraph proto.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if called in eager mode; if both
        device_options.dtype_override and device_options.fprop_dtype_override
        are set; if subgraph_filter is non-empty but matches none of the
        subgraphs defined by the task's Inference(); or if freezing is
        requested and cls._DeviceSupportsFreezing(device_options) is true.
    """
    if py_utils.IsEagerMode():
      raise ValueError('InferenceGraph exporter does not work in Eager mode.')
    assert issubclass(model_cfg.cls, base_model.BaseModel)
    if device_options.dtype_override and device_options.fprop_dtype_override:
      raise ValueError(
          'device_options.dtype_override and '
          'device_options.fprop_dtype_override can not both be set.')
    # Normalize to a collection so `name in subgraph_filter` below does exact
    # name matching; a bare string would otherwise match substrings.
    if subgraph_filter and not isinstance(subgraph_filter,
                                          (tuple, list, set, frozenset)):
      subgraph_filter = [subgraph_filter]

    # Disable assertions unless user explicitly enables it.
    if FLAGS['enable_asserts'].using_default_value:
      FLAGS.enable_asserts = False

    # TODO(laurenzo): Work out how much we need to specify here in terms of
    # cluster configuration.
    cls._SetClusterParams(model_cfg.cluster, device_options)

    # Configure the model.
    model_cfg.random_seed = random_seed
    model_cfg.is_inference = True

    if disable_packed_input:

      def _DisablePackedInput(task):
        if (_ParamExists(task, 'encoder') and
            _ParamExists(task.encoder, 'packed_input')):
          task.encoder.packed_input = False
        if (_ParamExists(task, 'decoder') and
            _ParamExists(task.decoder, 'packed_input')):
          task.decoder.packed_input = False

      if issubclass(model_cfg.cls, base_model.MultiTaskModel):
        for _, task_param in model_cfg.task_params.IterParams():
          _DisablePackedInput(task_param)
      else:
        _DisablePackedInput(model_cfg.task)

    tf.logging.debug('Model %s params:', model_cfg.name)
    for line in model_cfg.ToText().split('\n'):
      tf.logging.debug('%s', line)

    # Instantiate the graph.
    graph = tf.Graph()
    with graph.as_default():
      tf.random.set_seed(random_seed)
      cluster = model_cfg.cluster.Instantiate()
      device = cluster.GetPlacer()
      tpu_const_scope = _DummyScope()
      if (IsTpu(device_options) and
          device_options.var_options == 'AS_CONSTANTS'):
        # Do not specify devices for variables if we are marking them as
        # constants.
        device = ''
        tpu_const_scope = ConstGuaranteeScope()

      with cluster, tf.device(device), tpu_const_scope:

        bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
            device_options)

        if bfloat16_override:
          py_utils.UpdateDtype(model_cfg, tf.bfloat16)
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        act_bfloat16_override = ShouldForceBfloat16ForActivations(
            device_options)
        if act_bfloat16_override:
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        # Hard-code TPU-related flags prior to instantiating model.
        old_enable_asserts = FLAGS.enable_asserts
        old_xla_device = FLAGS.xla_device
        if IsTpu(device_options):
          FLAGS.enable_asserts = False
          FLAGS.xla_device = 'tpu'

        try:
          mdl = model_cfg.Instantiate()
          task = mdl.GetTask(model_task_name)

          variables_to_restore = (
              _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
              mdl.ema.variables_to_restore(mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables
                .get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
            # For TPU embedding layers, if the table explicitly specifies the
            # inference dtype as bfloat16, the variables in the checkpoint must
            # already be in bfloat16, so we change back to bfloat16 to avoid
            # dtype mismatch.
            for var_name in (tpu_embedding_layers.TpuEmbeddingCollection.Get()
                             .inference_with_bfloat16_var_names):
              saver_var_spec[var_name] = variables_to_restore[var_name]
          else:
            saver_var_spec = variables_to_restore

          saver = tf.train.Saver(saver_var_spec)
          tf.variables_initializer(
              tf.global_variables(), name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          if freeze_checkpoint or freeze_defaults:
            # Replace variables with tensors using tf.identity in theta before
            # freezing to avoid the graph referencing types of DT_RESOURCE.
            def AddIdentityToTheta(layer):
              # pylint: disable=protected-access
              layer._private_theta = py_utils.Transform(tf.identity,
                                                        layer._private_theta)
              # pylint: enable=protected-access
              layer.children.Transform(AddIdentityToTheta)

            AddIdentityToTheta(task)

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          if not inference_graph_proto.subgraphs and subgraph_filter:
            raise ValueError(
                f'Subgraph filters {subgraph_filter} filtered out all '
                'subgraphs. Defined subgraphs: '
                f'{list(subgraphs_proto.subgraphs.keys())}')

          # Yes, graph collections are bad, however this seems to be the
          # easiest way to get this assets registered from
          # TextFileInitializer.
          assets_collection = tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
          for asset in assets_collection:
            if asset.op.type == 'Const' and asset.op.get_attr(
                'dtype') == tf.dtypes.string:
              constant_value = asset.op.get_attr('value')
              if constant_value.string_val:
                tf.logging.info('Found asset file_path: %s',
                                constant_value.string_val[0])
                asset_file_def = inference_graph_proto.asset_file_def.add()
                asset_file_def.tensor_info.name = asset.name
                asset_file_def.filename = constant_value.string_val[0]

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

    tf.logging.info('Graph contains ops: %r',
                    [op.name for op in graph.get_operations()])

    # Collection defs
    if not tf.executing_eagerly():
      if export_graph_collections:
        meta_graph = tf.train.export_meta_graph(graph=graph)
        for key in meta_graph.collection_def:
          tf.logging.info('copying collection %s', key)
          inference_graph_proto.collection_def[key].CopyFrom(
              meta_graph.collection_def[key])
    else:
      tf.logging.warning('Not exporting collection defs '
                         'since operating in eager mode.')

    # Freezing.
    if freeze_defaults or freeze_checkpoint:
      output_op_names = GetOutputOpNames(
          graph,
          inference_graph_proto,
          preserve_colocation_nodes=False,
          preserve_saver_restore_nodes=False)
      # NOTE(review): despite its name, a true result from
      # _DeviceSupportsFreezing rejects freezing here. Confirm the helper's
      # semantics before changing this condition.
      if cls._DeviceSupportsFreezing(device_options):
        raise ValueError('freeze_checkpoint cannot be used with device ' +
                         device_options.device)
      if freeze_checkpoint:
        tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
        graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                               output_op_names)
      elif freeze_defaults:
        tf.logging.info('Default initializing graph and freezing.')
        graph_def = _FreezeDefaults(graph, output_op_names)
    else:
      inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())
      graph_def = graph.as_graph_def()

      if prune_graph:
        output_op_names = GetOutputOpNames(graph, inference_graph_proto)

        # Prune the graph to just the parts we need.
        # To support restoring, we have to not prune out the restore node.
        output_op_names.append('init_all_tables')
        output_op_names.append('init_all_variables')
        output_op_names.append('save/control_dependency')
        output_op_names.append('save/restore_all')
        if IsTpu(device_options) and device_options.gen_init_op:
          output_op_names.append('tpu_init_op')

        tf.logging.info('Pruning graph to output ops: %r', output_op_names)
        graph_def = tf.compat.v1.graph_util.extract_sub_graph(
            graph_def, output_op_names)

    if not device_options.retain_device_placement:
      # Clear the device so that the runtime can choose.
      tf.logging.info('Clearing device placement for: %s',
                      device_options.device)
      for node in graph_def.node:
        node.ClearField('device')
      for function in graph_def.library.function:
        for node_def in function.node_def:
          node_def.ClearField('device')

    inference_graph_proto.graph_def.CopyFrom(graph_def)

    if export_path:
      with tf.io.gfile.GFile(export_path, 'w') as f:
        f.write(text_format.MessageToString(inference_graph_proto))
    return inference_graph_proto