Example No. 1
    def RestoreIfNeeded(self, sess):
        """If vars are not initialized, restore frome checkpoint.

    Args:
      sess: tf.Session.
    """
        assert not self._save_only
        uninitialized_var_names = list(sess.run(self._uninitialized_vars))
        if not uninitialized_var_names:
            return

        tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)
        if self._Restore(sess):
            return

        if (not any(task.params.train.init_from_checkpoint_rules
                    for task in self._model_tasks)
                and not self._params.train.init_from_checkpoint_rules):
            tf.logging.info('Initialize ALL variables: %s',
                            uninitialized_var_names)
            sess.run([self._initialize_vars])
            tf.logging.info('Initialize variables done.')
            return

        # There is a race in local runs: another thread gets unblocked once
        # _initialize_all is called, so OverrideVarsFromCheckpoints might not
        # happen at the right time.
        for task in self._model.tasks:
            tp = task.params.train
            if tp.init_from_checkpoint_rules:
                tf.logging.info('OverrideVarsFromCheckpoints %s',
                                tp.init_from_checkpoint_rules)
                py_utils.OverrideVarsFromCheckpoints(
                    sess, self._vars, tp.init_from_checkpoint_rules)

        if self._params.train.init_from_checkpoint_rules:
            tp = self._params.train
            tf.logging.info('OverrideVarsFromCheckpoints %s',
                            tp.init_from_checkpoint_rules)
            py_utils.OverrideVarsFromCheckpoints(sess, self._vars,
                                                 tp.init_from_checkpoint_rules)

        uninitialized_var_names = list(sess.run(self._uninitialized_vars))
        if not uninitialized_var_names:
            return

        # uninitialized_var_names is a list of strings without ":0" suffix.
        # tf.report_uninitialized_variables returns binary strings.
        assert all(
            isinstance(s, six.binary_type) for s in uninitialized_var_names)

        # Need to retrieve vars, removing ":0" suffix from names.
        uninitialized_vars = [
            v for v in self._vars
            if six.ensure_binary(v.name[:-2]) in uninitialized_var_names
        ]
        tf.logging.info('Initialize variables: %s',
                        [v.name for v in uninitialized_vars])
        sess.run(tf.variables_initializer(uninitialized_vars))
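The method above leans on lingvo-specific helpers (self._uninitialized_vars, self._Restore, py_utils), but the underlying pattern is plain TF1: ask the session which variables are still uninitialized, strip the ":0" tensor suffix from variable names, and initialize only the remainder. A minimal sketch of that pattern, assuming graph (non-eager) mode and no lingvo code:

import tensorflow.compat.v1 as tf

def init_uninitialized_vars(sess):
  # report_uninitialized_variables returns variable names as byte strings,
  # without the ":0" suffix that tf.Variable.name carries.
  uninitialized = set(sess.run(tf.report_uninitialized_variables()))
  to_init = [v for v in tf.global_variables()
             if v.name[:-2].encode() in uninitialized]
  if to_init:
    sess.run(tf.variables_initializer(to_init))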
Example No. 2
 def _create_session(self, *args, **kwargs):
     sess = super()._create_session(*args, **kwargs)
     with sess.graph.as_default():
         # Ensure the global_step variable is created in every new session.
         global_step = py_utils.GetOrCreateGlobalStepVar()
         sess.run(
             tf.cond(tf.is_variable_initialized(global_step), tf.no_op,
                     lambda: tf.variables_initializer([global_step])))
     return sess
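For comparison, stock TF1 already provides a create-or-reuse helper for the global step; the tf.cond guard above only adds per-session initialization on top of it. A small sketch using only standard TF1 APIs (no lingvo assumptions):

import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
  # Creates the 'global_step' variable if absent, otherwise reuses it.
  global_step = tf.train.get_or_create_global_step()
with tf.Session(graph=graph) as sess:
  sess.run(tf.variables_initializer([global_step]))
  print(sess.run(global_step))  # -> 0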
Example No. 3
  def Export(cls,
             model_cfg,
             model_task_name=None,
             device_options=InferenceDeviceOptions(
                 device='',
                 retain_device_placement=False,
                 var_options=None,
                 gen_init_op=True,
                 dtype_override=None),
             freeze_checkpoint=None,
             freeze_defaults=False,
             export_path=None,
             subgraph_filter=None,
             random_seed=None,
             disable_packed_input=True):
    """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default-initializes the graph and freezes it. Useful
        for early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A list of subgraph names. If not None or empty, export
        only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input for inference writing purposes.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
    assert issubclass(model_cfg.cls, base_model.BaseModel)

    # Disable assertions unless user explicitly enables it.
    if FLAGS['enable_asserts'].using_default_value:
      FLAGS.enable_asserts = False

    # TODO(laurenzo): Work out how much we need to specify here in terms of
    # cluster configuration.
    cls._SetClusterParams(model_cfg.cluster, device_options)

    # Configure the model.
    model_cfg.random_seed = random_seed
    model_cfg.is_inference = True

    if disable_packed_input:

      def _DisablePackedInput(task):
        if (_ParamExists(task, 'encoder') and
            _ParamExists(task.encoder, 'packed_input')):
          task.encoder.packed_input = False
        if (_ParamExists(task, 'decoder') and
            _ParamExists(task.decoder, 'packed_input')):
          task.decoder.packed_input = False

      if issubclass(model_cfg.cls, base_model.MultiTaskModel):
        for _, task_param in model_cfg.task_params.IterParams():
          _DisablePackedInput(task_param)
      else:
        _DisablePackedInput(model_cfg.task)

    tf.logging.info('Model %s params:', model_cfg.name)
    for line in model_cfg.ToText().split('\n'):
      tf.logging.info('%s', line)

    # Instantiate the graph.
    graph = tf.Graph()
    with graph.as_default():
      tf.random.set_seed(random_seed)
      cluster = model_cfg.cluster.Instantiate()
      device = cluster.GetPlacer()
      tpu_const_scope = _DummyScope()
      if (IsTpu(device_options) and
          device_options.var_options == 'AS_CONSTANTS'):
        # Do not specify devices for variables if we are marking them as
        # constants.
        device = ''
        tpu_const_scope = ConstGuaranteeScope()

      with cluster, tf.device(device), tpu_const_scope:

        bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
            device_options)

        if bfloat16_override:
          py_utils.UpdateDtype(model_cfg, tf.bfloat16)
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        # Hard-code TPU-related flags prior to instantiating model.
        old_enable_asserts = FLAGS.enable_asserts
        old_xla_device = FLAGS.xla_device
        if IsTpu(device_options):
          FLAGS.enable_asserts = False
          FLAGS.xla_device = 'tpu'

        # Ensure the global_step variable is created.
        _ = py_utils.GetOrCreateGlobalStepVar()
        try:
          mdl = model_cfg.Instantiate()
          task = mdl.GetTask(model_task_name)

          variables_to_restore = (
              _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
              mdl.ema.variables_to_restore(mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables
                .get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
          else:
            saver_var_spec = variables_to_restore

          saver = tf.train.Saver(saver_var_spec)
          tf.variables_initializer(
              tf.global_variables(), name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

    tf.logging.info('Graph contains ops: %r',
                    [op.name for op in graph.get_operations()])

    inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

    # Freezing.
    if freeze_defaults or freeze_checkpoint:
      output_op_names = GetOutputOpNames(
          graph, inference_graph_proto, preserve_colocation_nodes=False)
      if cls._DeviceSupportsFreezing(device_options):
        raise ValueError('freeze_checkpoint cannot be used with device ' +
                         device_options.device)
      if freeze_checkpoint:
        tf.logging.info('Freezing graph from checkpoint: %s',
                        freeze_checkpoint)
        graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                               output_op_names)
      elif freeze_defaults:
        tf.logging.info('Default initializing graph and freezing.')
        graph_def = _FreezeDefaults(graph, output_op_names)
    else:
      output_op_names = GetOutputOpNames(graph, inference_graph_proto)

      # Prune the graph to just the parts we need.
      # To support restoring, we have to not prune out the restore node.
      output_op_names.append('init_all_tables')
      output_op_names.append('init_all_variables')
      output_op_names.append('save/control_dependency')
      output_op_names.append('save/restore_all')
      if IsTpu(device_options) and device_options.gen_init_op:
        output_op_names.append('tpu_init_op')
      graph_def = graph.as_graph_def()
      tf.logging.info('Pruning graph to output ops: %r', output_op_names)
      graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

    if not device_options.retain_device_placement:
      # Clear the device so that the runtime can choose.
      tf.logging.info('Clearing device placement for: %s',
                      device_options.device)
      for node in graph_def.node:
        node.ClearField('device')
      for function in graph_def.library.function:
        for node_def in function.node_def:
          node_def.ClearField('device')

    inference_graph_proto.graph_def.CopyFrom(graph_def)

    if export_path:
      with tf.io.gfile.GFile(export_path, 'w') as f:
        f.write(text_format.MessageToString(inference_graph_proto))
    return inference_graph_proto
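A call site for this classmethod could look roughly like the sketch below; the exporter module/class names and the model name are assumptions for illustration, not taken from the snippet itself:

from lingvo import model_registry
from lingvo.core import inference_graph_exporter  # assumed module path

# 'image.mnist.LeNet5' is an illustrative registered model name.
model_cfg = model_registry.GetParams('image.mnist.LeNet5', 'Test')
inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
    model_cfg,
    freeze_defaults=True,  # no checkpoint needed for a quick smoke test
    export_path='/tmp/inference_graph.pbtxt')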
Example No. 4
    def RestoreIfNeeded(self, sess):
        """If vars are not initialized, restore from checkpoint."""
        assert not self._save_only
        uninitialized_var_names = self.GetUninitializedVars(sess)
        # uninitialized_var_names is a list of strings without ":0" suffix.
        # tf.report_uninitialized_variables returns binary strings.
        assert all(
            isinstance(s, six.binary_type) for s in uninitialized_var_names)
        if not uninitialized_var_names:
            # All variables are already initialized.
            return

        tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)

        # There should only be uninitialized variables if all variables are
        # uninitialized.
        all_var_names = [
            six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
        ]
        assert (set(uninitialized_var_names) == set(all_var_names)
                ), sorted(set(all_var_names) - set(uninitialized_var_names))

        if self._Restore(sess):
            # Successfully restored from checkpoint.
            uninitialized_var_names = self.GetUninitializedVars(sess)
            assert not uninitialized_var_names, uninitialized_var_names
            return

        if (self._params.train.init_from_checkpoint_rules
                or any(task.params.train.init_from_checkpoint_rules
                       for task in self._model_tasks)):
            for task in self._model.tasks:
                tp = task.params.train
                if tp.init_from_checkpoint_rules:
                    tf.logging.info('OverrideVarsFromCheckpoints %s',
                                    tp.init_from_checkpoint_rules)
                    py_utils.OverrideVarsFromCheckpoints(
                        sess, tf.global_variables(),
                        tp.init_from_checkpoint_rules)

            if self._params.train.init_from_checkpoint_rules:
                tp = self._params.train
                tf.logging.info('OverrideVarsFromCheckpoints %s',
                                tp.init_from_checkpoint_rules)
                py_utils.OverrideVarsFromCheckpoints(
                    sess, tf.global_variables(), tp.init_from_checkpoint_rules)

            uninitialized_var_names = self.GetUninitializedVars(sess)
            if not uninitialized_var_names:
                return

            tf.logging.info('Remaining uninitialized vars: %s',
                            uninitialized_var_names)

        # Need to retrieve vars, removing ":0" suffix from names.
        uninitialized_vars = [
            v for v in tf.global_variables()
            if six.ensure_binary(v.name[:-2]) in uninitialized_var_names
        ]
        tf.logging.info('Initialize variables: %s',
                        sorted([v.name[:-2] for v in uninitialized_vars]))
        sess.run(tf.variables_initializer(uninitialized_vars))
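The ":0" bookkeeping used in both RestoreIfNeeded variants is easy to trip over: tf.Variable.name carries a tensor suffix, while tf.report_uninitialized_variables returns bare byte strings. A tiny standalone illustration (the variable name is made up):

import six
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  v = tf.get_variable('layer0/w', shape=[2, 2])
  print(v.name)                          # 'layer0/w:0'  (tensor name)
  print(six.ensure_binary(v.name[:-2]))  # b'layer0/w'   (comparable to the
                                         # uninitialized-variable names)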
Example No. 5
  def Export(cls,
             model_cfg,
             model_task_name=None,
             device_options=InferenceDeviceOptions(
                 device='',
                 retain_device_placement=False,
                 var_options=None,
                 gen_init_op=True,
                 dtype_override=None,
                 fprop_dtype_override=None),
             freeze_checkpoint=None,
             freeze_defaults=False,
             export_path=None,
             subgraph_filter=None,
             random_seed=None,
             disable_packed_input=True,
             prune_graph=True,
             export_graph_collections=False):
    """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

    Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
    and multi-core inference on TPUs work properly.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default-initializes the graph and freezes it. Useful
        for early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A string or a list of subgraph names. If not None or
        empty, export only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input for inference writing purposes.
      prune_graph: If true, prune the graph to just the parts we need.
      export_graph_collections: If true, export graph collections to the
        InferenceGraph proto.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
    if py_utils.IsEagerMode():
      raise ValueError('InferenceGraph exporter does not work in Eager mode.')
    assert issubclass(model_cfg.cls, base_model.BaseModel)
    if device_options.dtype_override and device_options.fprop_dtype_override:
      raise ValueError(
          'device_options.{dtype_override, fprop_dtype_override} cannot both '
          'be set.')
    if subgraph_filter and not isinstance(subgraph_filter, (tuple, list)):
      subgraph_filter = [subgraph_filter]

    # Disable assertions unless user explicitly enables it.
    if FLAGS['enable_asserts'].using_default_value:
      FLAGS.enable_asserts = False

    # TODO(laurenzo): Work out how much we need to specify here in terms of
    # cluster configuration.
    cls._SetClusterParams(model_cfg.cluster, device_options)

    # Configure the model.
    model_cfg.random_seed = random_seed
    model_cfg.is_inference = True

    if disable_packed_input:

      def _DisablePackedInput(task):
        if (_ParamExists(task, 'encoder') and
            _ParamExists(task.encoder, 'packed_input')):
          task.encoder.packed_input = False
        if (_ParamExists(task, 'decoder') and
            _ParamExists(task.decoder, 'packed_input')):
          task.decoder.packed_input = False

      if issubclass(model_cfg.cls, base_model.MultiTaskModel):
        for _, task_param in model_cfg.task_params.IterParams():
          _DisablePackedInput(task_param)
      else:
        _DisablePackedInput(model_cfg.task)

    tf.logging.debug('Model %s params:', model_cfg.name)
    for line in model_cfg.ToText().split('\n'):
      tf.logging.debug('%s', line)

    # Instantiate the graph.
    graph = tf.Graph()
    with graph.as_default():
      tf.random.set_seed(random_seed)
      cluster = model_cfg.cluster.Instantiate()
      device = cluster.GetPlacer()
      tpu_const_scope = _DummyScope()
      if (IsTpu(device_options) and
          device_options.var_options == 'AS_CONSTANTS'):
        # Do not specify devices for variables if we are marking them as
        # constants.
        device = ''
        tpu_const_scope = ConstGuaranteeScope()

      with cluster, tf.device(device), tpu_const_scope:

        bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
            device_options)

        if bfloat16_override:
          py_utils.UpdateDtype(model_cfg, tf.bfloat16)
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        act_bfloat16_override = ShouldForceBfloat16ForActivations(
            device_options)
        if act_bfloat16_override:
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        # Hard-code TPU-related flags prior to instantiating model.
        old_enable_asserts = FLAGS.enable_asserts
        old_xla_device = FLAGS.xla_device
        if IsTpu(device_options):
          FLAGS.enable_asserts = False
          FLAGS.xla_device = 'tpu'

        try:
          mdl = model_cfg.Instantiate()
          task = mdl.GetTask(model_task_name)

          variables_to_restore = (
              _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
              mdl.ema.variables_to_restore(mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables
                .get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
            # For TPU embedding layers, if the table explicitly specifies the
            # inference dtype as bfloat16, the variables in the checkpoint must
            # already be in bfloat16, so we change back to bfloat16 to avoid
            # dtype mismatch.
            for var_name in (tpu_embedding_layers.TpuEmbeddingCollection.Get()
                             .inference_with_bfloat16_var_names):
              saver_var_spec[var_name] = variables_to_restore[var_name]
          else:
            saver_var_spec = variables_to_restore

          saver = tf.train.Saver(saver_var_spec)
          tf.variables_initializer(
              tf.global_variables(), name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          if freeze_checkpoint or freeze_defaults:
            # Replace variables with tensors using tf.identity in theta before
            # freezing to avoid the graph referencing types of DT_RESOURCE.
            def AddIdentityToTheta(layer):
              # pylint: disable=protected-access
              layer._private_theta = py_utils.Transform(tf.identity,
                                                        layer._private_theta)
              # pylint: enable=protected-access
              layer.children.Transform(AddIdentityToTheta)

            AddIdentityToTheta(task)

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          if not inference_graph_proto.subgraphs and subgraph_filter:
            raise ValueError(
                f'Subgraph filters {subgraph_filter} filtered out all '
                'subgraphs. Defined subgraphs: '
                f'{list(subgraphs_proto.subgraphs.keys())}')

          # Yes, graph collections are bad; however, this seems to be the
          # easiest way to get these assets registered from
          # TextFileInitializer.
          assets_collection = tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
          for asset in assets_collection:
            if asset.op.type == 'Const' and asset.op.get_attr(
                'dtype') == tf.dtypes.string:
              constant_value = asset.op.get_attr('value')
              if constant_value.string_val:
                tf.logging.info('Found asset file_path: %s',
                                constant_value.string_val[0])
                asset_file_def = inference_graph_proto.asset_file_def.add()
                asset_file_def.tensor_info.name = asset.name
                asset_file_def.filename = constant_value.string_val[0]

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

    tf.logging.info('Graph contains ops: %r',
                    [op.name for op in graph.get_operations()])

    # Collection defs
    if not tf.executing_eagerly():
      if export_graph_collections:
        meta_graph = tf.train.export_meta_graph(graph=graph)
        for key in meta_graph.collection_def:
          tf.logging.info('copying collection %s', key)
          inference_graph_proto.collection_def[key].CopyFrom(
              meta_graph.collection_def[key])
    else:
      tf.logging.warning('Not exporting collection defs '
                         'since operating in eager mode.')

    # Freezing.
    if freeze_defaults or freeze_checkpoint:
      output_op_names = GetOutputOpNames(
          graph,
          inference_graph_proto,
          preserve_colocation_nodes=False,
          preserve_saver_restore_nodes=False)
      if cls._DeviceSupportsFreezing(device_options):
        raise ValueError('freeze_checkpoint cannot be used with device ' +
                         device_options.device)
      if freeze_checkpoint:
        tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
        graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                               output_op_names)
      elif freeze_defaults:
        tf.logging.info('Default initializing graph and freezing.')
        graph_def = _FreezeDefaults(graph, output_op_names)
    else:
      inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())
      graph_def = graph.as_graph_def()

      if prune_graph:
        output_op_names = GetOutputOpNames(graph, inference_graph_proto)

        # Prune the graph to just the parts we need.
        # To support restoring, we have to not prune out the restore node.
        output_op_names.append('init_all_tables')
        output_op_names.append('init_all_variables')
        output_op_names.append('save/control_dependency')
        output_op_names.append('save/restore_all')
        if IsTpu(device_options) and device_options.gen_init_op:
          output_op_names.append('tpu_init_op')

        tf.logging.info('Pruning graph to output ops: %r', output_op_names)
        graph_def = tf.compat.v1.graph_util.extract_sub_graph(
            graph_def, output_op_names)

    if not device_options.retain_device_placement:
      # Clear the device so that the runtime can choose.
      tf.logging.info('Clearing device placement for: %s',
                      device_options.device)
      for node in graph_def.node:
        node.ClearField('device')
      for function in graph_def.library.function:
        for node_def in function.node_def:
          node_def.ClearField('device')

    inference_graph_proto.graph_def.CopyFrom(graph_def)

    if export_path:
      with tf.io.gfile.GFile(export_path, 'w') as f:
        f.write(text_format.MessageToString(inference_graph_proto))
    return inference_graph_proto
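Reading an exported ASCII proto back uses the same text_format and inference_graph_pb2 modules that appear above. A short sketch, with the import path and file location assumed for illustration:

from google.protobuf import text_format
from lingvo.core import inference_graph_pb2  # assumed module path
import tensorflow.compat.v1 as tf

with tf.io.gfile.GFile('/tmp/inference_graph.pbtxt', 'r') as f:
  inference_graph = text_format.Parse(
      f.read(), inference_graph_pb2.InferenceGraph())
print(list(inference_graph.subgraphs.keys()))  # names of exported subgraphs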