Code example #1
    def Restore(self, sess, force_reinitialize=False):
        """Restore from latest checkpoint if available, or initialize."""
        # Try and restore from the latest checkpoint.
        if self._RestoreFromLatestCheckpoint(sess):
            # Successfully restored from checkpoint.
            uninitialized_var_names = self._GetUninitializedVarNames(sess)
            assert not uninitialized_var_names, uninitialized_var_names
            return

        # Otherwise we need to initialize.
        uninitialized_var_names = self._GetUninitializedVarNames(sess)
        tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)
        if not force_reinitialize:
            # There should only be uninitialized variables if all variables are
            # uninitialized - with the exception of global_step due to
            # RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu.
            all_var_names = [
                six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
            ]
            already_initialized_vars = (set(all_var_names) -
                                        set(uninitialized_var_names))
            already_initialized_vars.discard(b'global_step')
            assert not already_initialized_vars, (
                'Already initialized vars: %s' %
                sorted(already_initialized_vars))

        # At this point all variables are uninitialized, so it is safe to run a
        # global initializer.
        sess.run(self._init_op)
        tf.logging.info('Initialized all vars.')

        # TODO(b/160786085): Move this logic into Overriding vars logic itself,
        # which requires refactoring things out of py_utils to avoid circular deps.
        def _ResolveCkptPath(ckpt_rules):
            return {GetSpecificCheckpoint(k): v for k, v in ckpt_rules.items()}

        # Restore specific variables based on init_from_checkpoint_rules.
        for task in self._model.tasks:
            tp = task.params.train
            if tp.init_from_checkpoint_rules:
                rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
                tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
                py_utils.OverrideVarsFromCheckpoints(sess,
                                                     tf.global_variables(),
                                                     rules)

        if self._params.train.init_from_checkpoint_rules:
            tp = self._params.train
            rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
            tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
            py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                                 rules)
Code example #2
    def __init__(self, train_dir, model, train_params=None, save_only=False):
        """Initialize Checkpointer.

    Args:
     train_dir: Training directory for saving checkpoints.
     model: A BaseModel instance or None.
     train_params: If specified, use these training params instead of those in
       the `model`.
     save_only: This checkpointer is only intended for saving checkpoints.
    """
        self._train_dir = train_dir
        self._save_only = save_only

        self._save_path = os.path.join(self._train_dir, 'ckpt')

        if train_params:
            self._train_params = train_params
            self._model = None
        else:
            assert model
            self._train_params = model.params.train
            self._model = model

        if not self._save_only:
            self._params = model.params
            self._model_tasks = model.tasks
            self._model = model

        self._next_checkpoint_seconds = 0
        self._save_interval_seconds = self._train_params.save_interval_seconds
        self._saver = self._GetSaver()

        self._uninitialized_vars = tf.report_uninitialized_variables(
            tf.global_variables())
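
Since this constructor reads `model.params` whenever `save_only` is False, passing `train_params` without a model only makes sense together with `save_only=True`. A hedged sketch of that save-only configuration (names are placeholders):

    # Hypothetical save-only checkpointer: no model needed, just train params.
    ckpt = Checkpointer('/tmp/train', model=None,
                        train_params=some_params.train, save_only=True)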
Code example #3
  def Restore(self, sess, force_reinitialize=False):
    """Restore from latest checkpoint if available, or initialize."""
    # Try and restore from the latest checkpoint.
    if self._RestoreFromLatestCheckpoint(sess):
      # Successfully restored from checkpoint.
      uninitialized_var_names = self._GetUninitializedVarNames(sess)
      assert not uninitialized_var_names, uninitialized_var_names
      return

    # Otherwise we need to initialize.
    uninitialized_var_names = self._GetUninitializedVarNames(sess)
    tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)
    if not force_reinitialize:
      # There should only be uninitialized variables if all variables are
      # uninitialized - with the exception of global_step due to
      # RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu.
      all_var_names = [
          six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
      ]
      already_initialized_vars = (
          set(all_var_names) - set(uninitialized_var_names))
      already_initialized_vars.discard(b'global_step')
      assert not already_initialized_vars, ('Already initialized vars: %s' %
                                            sorted(already_initialized_vars))

    # At this point all variables are uninitialized, so it is safe to run a
    # global initializer.
    sess.run(self._init_op)
    tf.logging.info('Initialized all vars.')

    if self._restore_fns:
      for fn in self._restore_fns:
        fn(sess)
      tf.logging.info('Restored vars using checkpoint rules.')
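
Unlike examples #1 and #4, this variant does not build override ops inside Restore; each entry in `self._restore_fns` is a closure over graph nodes created once at construction time. A generic sketch of that deferred pattern (not lingvo's exact helper):

    # Build assign ops once; the returned fn only runs them in a session.
    def MakeRestoreFn(assign_ops):
        def _Restore(sess):
            sess.run(assign_ops)  # no new graph nodes created here
        return _Restore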
Code example #4
    def Restore(self, sess, force_reinitialize=False):
        """Restore from latest checkpoint if available, or initialize."""
        # Try and restore from the latest checkpoint.
        if self._RestoreFromLatestCheckpoint(sess):
            # Successfully restored from checkpoint.
            uninitialized_var_names = self._GetUninitializedVarNames(sess)
            assert not uninitialized_var_names, uninitialized_var_names
            return

        # Otherwise we need to initialize.
        uninitialized_var_names = self._GetUninitializedVarNames(sess)
        tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)
        if not force_reinitialize:
            # There should only be uninitialized variables if all variables are
            # uninitialized - with the exception of global_step due to
            # RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu.
            all_var_names = [
                six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
            ]
            already_initialized_vars = (set(all_var_names) -
                                        set(uninitialized_var_names))
            already_initialized_vars.discard(b'global_step')
            assert not already_initialized_vars, (
                'Already initialized vars: %s' %
                sorted(already_initialized_vars))

        # At this point all variables are uninitialized, so it is safe to run a
        # global initializer.
        sess.run(tf.global_variables_initializer())
        tf.logging.info('Initialized all vars.')

        # Restore specific variables based on init_from_checkpoint_rules.
        for task in self._model.tasks:
            tp = task.params.train
            if tp.init_from_checkpoint_rules:
                tf.logging.info('OverrideVarsFromCheckpoints %s',
                                tp.init_from_checkpoint_rules)
                py_utils.OverrideVarsFromCheckpoints(
                    sess, tf.global_variables(), tp.init_from_checkpoint_rules)

        if self._params.train.init_from_checkpoint_rules:
            tp = self._params.train
            tf.logging.info('OverrideVarsFromCheckpoints %s',
                            tp.init_from_checkpoint_rules)
            py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                                 tp.init_from_checkpoint_rules)
Code example #5
File: checkpointer.py  Project: huaxz1986/lingvo
    def __init__(self,
                 train_dir,
                 models,
                 init_op=None,
                 train_params=None,
                 save_only=False):
        """Initialize Checkpointer.

    Args:
     train_dir: Training directory for saving checkpoints.
     models: One or a list of BaseModel instances. Cannot be empty. If there is
       more than one model and `train_params` is None, the save intervals are
       determined solely by the first model.
     init_op: The initialize variables op. If unset, it will call
       tf.global_variables_initializer().
     train_params: If specified, use these training params instead of those in
       the `model`.
     save_only: This checkpointer is only intended for saving checkpoints.
    """
        self._train_dir = train_dir
        self._save_only = save_only

        if init_op:
            self._init_op = init_op
        else:
            self._init_op = tf.global_variables_initializer()

        self._save_path = os.path.join(self._train_dir, 'ckpt')

        if not isinstance(models, (list, tuple)):
            models = [models]
        self._models = models

        if train_params:
            self._train_params = train_params
        else:
            self._train_params = models[0].params.train

        self._next_checkpoint_seconds = 0
        self._save_interval_seconds = self._train_params.save_interval_seconds
        self._save_interval_steps = self._train_params.save_interval_steps
        self._prev_ckpt_step = None
        self._saver = self._GetSaver()

        if not py_utils.IsEagerMode():
            self._uninitialized_vars = tf.report_uninitialized_variables(
                tf.global_variables())

        self._BuildInitFromCheckpointRules()
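
A short sketch of the multi-model form this constructor accepts (the model objects are placeholders); with no explicit `train_params`, save intervals come from the first model in the list:

    # Hypothetical: one checkpointer covering two jointly trained models.
    ckpt = Checkpointer('/tmp/train', [model_a, model_b])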
Code example #6
def _WrapNonLingvoVars(dest_layer: base_layer.BaseLayer,
                       variables: Collection[tf.Variable],
                       trainable_variables: Collection[tf.Variable] = ()):
    """Adds variables to the given lingvo layer and appropriate graph collections.

  This function helps wrap variables created outside of lingvo so they are
  correctly handled by lingvo's trainer and checkpointer. It does the following:

    - makes all `variables` trackable through `dest_layer.vars`;
    - ensures `variables` are in the `tf.global_variables()` graph collection so
      the trainer can initialize them;
    - adds the `trainable_variables` subset to the `tf.trainable_variables()`
      graph collection, so they are visible to the learner (i.e. can be
      trained).

  Args:
    dest_layer: Lingvo layer to add the `variables` to.
    variables: The non-lingvo variables to wrap.
    trainable_variables: The subset of `variables` to ensure are trainable.
  """

    global_collection = set(tf.global_variables())
    for v in variables:
        assert v in global_collection
        name = v.name.split(':')[0]
        # pylint: disable=protected-access
        dest_layer._private_vars[name] = v
        with tf.device(v.device):
            dest_layer._private_theta[name] = tf.identity(v)
        # pylint: enable=protected-access

    trainable_collection = set(tf.trainable_variables())
    for v in trainable_variables:
        if v not in trainable_collection:
            tf.logging.warning(
                'Wrapped var %s not in trainable collection; adding it.',
                v.name)
            tf.compat.v1.add_to_collection(
                tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, v)
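
A hedged sketch of calling this helper on a variable created outside lingvo (`my_layer` and the variable name are placeholders; in TF1 graph mode `get_variable` already places the variable in `tf.global_variables()`, satisfying the assert above):

    v = tf.compat.v1.get_variable('external_w', shape=[3, 2])
    _WrapNonLingvoVars(my_layer, variables=[v], trainable_variables=[v])
    # v is now reachable via my_layer.vars and tf.trainable_variables().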
Code example #7
    def testAccumulator(self):
        # testAccumulator compares
        #   - explicit averaging of independently computed var_grads1 and
        #     var_grads2,
        #   - Accumulator(SGD) optimizer effectively doing this over 2 steps.
        np.random.seed(12345)
        np_input1 = np.random.normal(0.1, 0.5, [2, 4, 3])
        np.random.seed(12346)
        np_input2 = np.random.normal(0.1, 0.5, [2, 4, 3])

        with self.session(use_gpu=True, graph=tf.Graph()) as sess:
            tf.random.set_seed(123456)
            params = layers.ProjectionLayer.Params()
            params.name = 'proj'
            params.dtype = tf.float64
            params.input_dim = 3
            params.output_dim = 2
            params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)

            params.batch_norm = False
            proj_layer = layers.ProjectionLayer(params)
            inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
            in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
            inputs2 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
            in_padding2 = tf.zeros([2, 4, 1], dtype=tf.float64)
            output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
            output2 = proj_layer.FPropDefaultTheta(inputs2, in_padding2)
            loss1 = tf.reduce_sum(output1)
            loss2 = tf.reduce_sum(output2)
            var_grads1 = py_utils.ComputeGradients(loss1, proj_layer.vars)
            var_grads2 = py_utils.ComputeGradients(loss2, proj_layer.vars)
            op = optimizer.SGD.Params()
            opt = op.Instantiate()
            lr = 1e-1
            with tf.control_dependencies([loss1, loss2]):
                var_update_op1 = opt.Apply(
                    lr, py_utils.ApplyGradMultiplier(var_grads1, 1. / 2.))
                with tf.control_dependencies([var_update_op1]):
                    var_update_op2 = opt.Apply(
                        lr, py_utils.ApplyGradMultiplier(var_grads2, 1. / 2.))

            self.evaluate(tf.global_variables_initializer())
            vars1 = self.evaluate(proj_layer.vars.Flatten())
            loss1_1, grads1_1, loss1_2, grads1_2 = sess.run(
                [
                    loss1,
                    var_grads1.Transform(tuple), loss2,
                    var_grads2.Transform(tuple)
                ],
                feed_dict={
                    inputs1: np_input1,
                    inputs2: np_input2,
                },
            )
            sess.run([var_update_op2],
                     feed_dict={
                         inputs1: np_input1,
                         inputs2: np_input2,
                     })
            vars1_1 = self.evaluate(proj_layer.vars.Flatten())

        with self.session(use_gpu=True, graph=tf.Graph()) as sess:
            tf.random.set_seed(123456)
            params = layers.ProjectionLayer.Params()
            params.name = 'proj'
            params.dtype = tf.float64
            params.input_dim = 3
            params.output_dim = 2
            params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)

            params.batch_norm = False
            proj_layer = layers.ProjectionLayer(params)
            in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
            inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
            output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
            loss = tf.reduce_sum(output1)
            var_grads = py_utils.ComputeGradients(loss, proj_layer.vars)
            op = optimizer.Accumulator.Params().Set(
                accum_steps=2,
                dtype=tf.float64,
                optimizer_tpl=optimizer.SGD.Params())
            opt = op.Instantiate()
            lr = 1e-1
            with cluster_factory.ForTestingWorker(add_summary=True):
                var_update_op = opt.Apply(lr, var_grads)
            increment_global_step_op = tf.assign_add(
                py_utils.GetOrCreateGlobalStepVar(), 1)

            self.evaluate(tf.global_variables_initializer())
            vars2 = self.evaluate(proj_layer.vars.Flatten())
            loss2_1, grads2_1 = sess.run(
                [loss, var_grads.Transform(tuple)],
                feed_dict={
                    inputs1: np_input1,
                })
            loss2_2, grads2_2 = sess.run(
                [loss, var_grads.Transform(tuple)],
                feed_dict={
                    inputs1: np_input2,
                })
            acc_0 = self.evaluate([
                v for v in tf.global_variables()
                if 'grad_accumulator' in v.name
            ])[0]
            sess.run([var_update_op], feed_dict={
                inputs1: np_input1,
            })
            acc_1 = self.evaluate([
                v for v in tf.global_variables()
                if 'grad_accumulator' in v.name
            ])[0]
            vars2_intermediate = self.evaluate(proj_layer.vars.Flatten())
            self.evaluate(increment_global_step_op)
            sess.run([var_update_op], feed_dict={
                inputs1: np_input2,
            })
            acc_2 = self.evaluate([
                v for v in tf.global_variables()
                if 'grad_accumulator' in v.name
            ])[0]
            vars2_1 = self.evaluate(proj_layer.vars.Flatten())

            summary = tf.Summary.FromString(
                self.evaluate(tf.summary.merge_all()))
            tf.logging.info(f'summary: {summary}')
            self.assertEqual(summary.value[0].tag, 'sgd_lr')

        self.assertAllClose(vars1, vars2)

        self.assertAllClose(acc_0, np.zeros_like(acc_0))
        self.assertAllClose(acc_1, grads2_1['w'][1])
        self.assertAllClose(acc_2, np.zeros_like(acc_0))

        self.assertAllClose(loss1_1, loss2_1)
        self.assertAllClose(loss1_2, loss2_2)
        self.assertAllClose(grads1_1, grads2_1)
        self.assertAllClose(grads1_2, grads2_2)

        self.assertAllClose(vars1, vars2_intermediate)

        self.assertAllClose(vars2[0], grads2_1['w'][0])
        self.assertAllClose(vars2[0], grads2_2['w'][0])

        self.assertAllClose(
            vars1[0] - 0.5 * lr * (grads1_1['w'][1] + grads1_2['w'][1]),
            vars1_1[0])

        self.assertAllClose(
            vars2[0] - 0.5 * lr * (grads2_1['w'][1] + grads2_2['w'][1]),
            vars2_1[0])

        self.assertAllClose(vars2, vars2_intermediate)
        self.assertAllClose(vars1_1, vars2_1)
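
The equivalence this test checks can be stated without TensorFlow at all: one SGD step on the average of two gradients equals what Accumulator(SGD) applies after accumulating for two steps. A standalone numpy sketch:

    import numpy as np
    w = np.array([1.0, 2.0])
    lr = 0.1
    g1, g2 = np.array([0.3, -0.1]), np.array([0.5, 0.2])
    w_explicit = w - lr * 0.5 * (g1 + g2)  # average grads, then one step
    acc = (g1 + g2) / 2.                   # accumulator state after 2 steps
    w_accum = w - lr * acc                 # applied once at the accum boundary
    assert np.allclose(w_explicit, w_accum)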
Code example #8
  def Export(cls,
             model_cfg,
             model_task_name=None,
             device_options=InferenceDeviceOptions(
                 device='',
                 retain_device_placement=False,
                 var_options=None,
                 gen_init_op=True,
                 dtype_override=None),
             freeze_checkpoint=None,
             freeze_defaults=False,
             export_path=None,
             subgraph_filter=None,
             random_seed=None,
             disable_packed_input=True):
    """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default-initializes the graph and freezes it. Useful
        for early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A list of subgraph names. If not None or empty, export
        only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input when writing the inference
        graph.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
    assert issubclass(model_cfg.cls, base_model.BaseModel)

    # Disable assertions unless user explicitly enables it.
    if FLAGS['enable_asserts'].using_default_value:
      FLAGS.enable_asserts = False

    # TODO(laurenzo): Work out how much we need to specify here in terms of
    # cluster configuration.
    cls._SetClusterParams(model_cfg.cluster, device_options)

    # Configure the model.
    model_cfg.random_seed = random_seed
    model_cfg.is_inference = True

    if disable_packed_input:

      def _DisablePackedInput(task):
        if (_ParamExists(task, 'encoder') and
            _ParamExists(task.encoder, 'packed_input')):
          task.encoder.packed_input = False
        if (_ParamExists(task, 'decoder') and
            _ParamExists(task.decoder, 'packed_input')):
          task.decoder.packed_input = False

      if issubclass(model_cfg.cls, base_model.MultiTaskModel):
        for _, task_param in model_cfg.task_params.IterParams():
          _DisablePackedInput(task_param)
      else:
        _DisablePackedInput(model_cfg.task)

    tf.logging.info('Model %s params:', model_cfg.name)
    for line in model_cfg.ToText().split('\n'):
      tf.logging.info('%s', line)

    # Instantiate the graph.
    graph = tf.Graph()
    with graph.as_default():
      tf.random.set_seed(random_seed)
      cluster = model_cfg.cluster.Instantiate()
      device = cluster.GetPlacer()
      tpu_const_scope = _DummyScope()
      if (IsTpu(device_options) and
          device_options.var_options == 'AS_CONSTANTS'):
        # Do not specify devices for variables if we are marking them as
        # constants.
        device = ''
        tpu_const_scope = ConstGuaranteeScope()

      with cluster, tf.device(device), tpu_const_scope:

        bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
            device_options)

        if bfloat16_override:
          py_utils.UpdateDtype(model_cfg, tf.bfloat16)
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        # Hard-code TPU-related flags prior to instantiating model.
        old_enable_asserts = FLAGS.enable_asserts
        old_xla_device = FLAGS.xla_device
        if IsTpu(device_options):
          FLAGS.enable_asserts = False
          FLAGS.xla_device = 'tpu'

        # Ensure the global_step variable is created.
        _ = py_utils.GetOrCreateGlobalStepVar()
        try:
          mdl = model_cfg.Instantiate()
          task = mdl.GetTask(model_task_name)

          variables_to_restore = (
              _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
              mdl.ema.variables_to_restore(mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables
                .get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
          else:
            saver_var_spec = variables_to_restore

          saver = tf.train.Saver(saver_var_spec)
          tf.variables_initializer(
              tf.global_variables(), name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

    tf.logging.info('Graph contains ops: %r',
                    [op.name for op in graph.get_operations()])

    inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

    # Freezing.
    if freeze_defaults or freeze_checkpoint:
      output_op_names = GetOutputOpNames(
          graph, inference_graph_proto, preserve_colocation_nodes=False)
      if cls._DeviceSupportsFreezing(device_options):
        raise ValueError('freeze_checkpoint cannot be used with device ' +
                         device_options.device)
      if freeze_checkpoint:
        tf.logging.info('Freezing graph from checkpoint: %s',
                        freeze_checkpoint)
        graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                               output_op_names)
      elif freeze_defaults:
        tf.logging.info('Default initializing graph and freezing.')
        graph_def = _FreezeDefaults(graph, output_op_names)
    else:
      output_op_names = GetOutputOpNames(graph, inference_graph_proto)

      # Prune the graph to just the parts we need.
      # To support restoring, we have to not prune out the restore node.
      output_op_names.append('init_all_tables')
      output_op_names.append('init_all_variables')
      output_op_names.append('save/control_dependency')
      output_op_names.append('save/restore_all')
      if IsTpu(device_options) and device_options.gen_init_op:
        output_op_names.append('tpu_init_op')
      graph_def = graph.as_graph_def()
      tf.logging.info('Pruning graph to output ops: %r', output_op_names)
      graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

    if not device_options.retain_device_placement:
      # Clear the device so that the runtime can choose.
      tf.logging.info('Clearing device placement for: %s',
                      device_options.device)
      for node in graph_def.node:
        node.ClearField('device')
      for function in graph_def.library.function:
        for node_def in function.node_def:
          node_def.ClearField('device')

    inference_graph_proto.graph_def.CopyFrom(graph_def)

    if export_path:
      with tf.io.gfile.GFile(export_path, 'w') as f:
        f.write(text_format.MessageToString(inference_graph_proto))
    return inference_graph_proto
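
A hedged invocation sketch, assuming this method lives on lingvo's InferenceGraphExporter class; the registered model name and subgraph name are placeholders:

    from lingvo import model_registry
    from lingvo.core import inference_graph_exporter

    params = model_registry.GetParams('image.mnist.LeNet5', 'Test')
    inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
        params,
        export_path='/tmp/inference.pbtxt',
        subgraph_filter=['default'])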
Code example #9
  def __init__(self,
               train_dir,
               model,
               init_op=None,
               train_params=None,
               save_only=False):
    """Initialize Checkpointer.

    Args:
     train_dir: Training directory for saving checkpoints.
     model: A BaseModel instance or None.
     init_op: The initialize variables op. If unset, it will call
       tf.global_variables_initializer().
     train_params: If specified, use these training params instead of those in
       the `model`.
     save_only: This checkpointer is only intended for saving checkpoints.
    """
    self._train_dir = train_dir
    self._save_only = save_only
    if init_op:
      self._init_op = init_op
    else:
      self._init_op = tf.global_variables_initializer()

    self._save_path = os.path.join(self._train_dir, 'ckpt')

    if train_params:
      self._train_params = train_params
      self._model = None
    else:
      assert model
      self._train_params = model.params.train
      self._model = model

    if self._save_only:
      self._params = None
    else:
      self._params = model.params
      self._model_tasks = model.tasks
      self._model = model

    self._next_checkpoint_seconds = 0
    self._save_interval_seconds = self._train_params.save_interval_seconds
    self._saver = self._GetSaver()

    self._uninitialized_vars = tf.report_uninitialized_variables(
        tf.global_variables())

    # TODO(b/160786085): Move this logic into Overriding vars logic itself,
    # which requires refactoring things out of py_utils to avoid circular deps.
    def _ResolveCkptPath(ckpt_rules):
      return {GetSpecificCheckpoint(k): v for k, v in ckpt_rules.items()}

    self._restore_fns = []

    # Add graph nodes to restore specific variables based on
    # init_from_checkpoint_rules.
    # TODO(b/159267006): Move this back to Restore().
    if self._model:
      for task in self._model.tasks:
        tp = task.params.train
        if tp.init_from_checkpoint_rules:
          rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
          tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
          fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(),
                                                    rules)
          self._restore_fns.append(fn)

    if self._params and self._params.train.init_from_checkpoint_rules:
      tp = self._params.train
      rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
      tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
      fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(), rules)
      self._restore_fns.append(fn)
Code example #10
    def RestoreIfNeeded(self, sess):
        """If vars are not initialized, restore from checkpoint."""
        assert not self._save_only
        uninitialized_var_names = self.GetUninitializedVars(sess)
        # uninitialized_var_names is a list of strings without ":0" suffix.
        # tf.report_uninitialized_variables returns binary strings.
        assert all(
            isinstance(s, six.binary_type) for s in uninitialized_var_names)
        if not uninitialized_var_names:
            # All variables are already initialized.
            return

        tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)

        # There should only be uninitialized variables if all variables are
        # uninitialized.
        all_var_names = [
            six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
        ]
        assert (set(uninitialized_var_names) == set(all_var_names)
                ), sorted(set(all_var_names) - set(uninitialized_var_names))

        if self._Restore(sess):
            # Successfully restored from checkpoint.
            uninitialized_var_names = self.GetUninitializedVars(sess)
            assert not uninitialized_var_names, uninitialized_var_names
            return

        if (self._params.train.init_from_checkpoint_rules
                or any(task.params.train.init_from_checkpoint_rules
                       for task in self._model_tasks)):
            for task in self._model.tasks:
                tp = task.params.train
                if tp.init_from_checkpoint_rules:
                    tf.logging.info('OverrideVarsFromCheckpoints %s',
                                    tp.init_from_checkpoint_rules)
                    py_utils.OverrideVarsFromCheckpoints(
                        sess, tf.global_variables(),
                        tp.init_from_checkpoint_rules)

            if self._params.train.init_from_checkpoint_rules:
                tp = self._params.train
                tf.logging.info('OverrideVarsFromCheckpoints %s',
                                tp.init_from_checkpoint_rules)
                py_utils.OverrideVarsFromCheckpoints(
                    sess, tf.global_variables(), tp.init_from_checkpoint_rules)

            uninitialized_var_names = self.GetUninitializedVars(sess)
            if not uninitialized_var_names:
                return

            tf.logging.info('Remaining uninitialized vars: %s',
                            uninitialized_var_names)

        # Need to retrieve vars, removing ":0" suffix from names.
        uninitialized_vars = [
            v for v in tf.global_variables()
            if six.ensure_binary(v.name[:-2]) in uninitialized_var_names
        ]
        tf.logging.info('Initialize variables: %s',
                        sorted([v.name[:-2] for v in uninitialized_vars]))
        sess.run(tf.variables_initializer(uninitialized_vars))
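
The `v.name[:-2]` slicing seen throughout these examples strips the ':0' output suffix so variable names can be compared against what `tf.report_uninitialized_variables` returns (bare byte strings). A standalone TF1-mode sketch:

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    v = tf.get_variable('w', shape=[2])
    with tf.Session() as sess:
        print(sess.run(tf.report_uninitialized_variables()))  # -> [b'w']
        print(v.name[:-2])  # -> 'w', matching the report's naming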
Code example #11
  def Export(cls,
             model_cfg,
             model_task_name=None,
             device_options=InferenceDeviceOptions(
                 device='',
                 retain_device_placement=False,
                 var_options=None,
                 gen_init_op=True,
                 dtype_override=None,
                 fprop_dtype_override=None),
             freeze_checkpoint=None,
             freeze_defaults=False,
             export_path=None,
             subgraph_filter=None,
             random_seed=None,
             disable_packed_input=True,
             prune_graph=True,
             export_graph_collections=False):
    """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

    Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
    and multi-core inference on TPUs work properly.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default-initializes the graph and freezes it. Useful
        for early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A string or a list of subgraph names. If not None or
        empty, export only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input when writing the inference
        graph.
      prune_graph: If true, prune the graph to just the parts we need.
      export_graph_collections: If true, export graph collections to the
        InferenceGraph proto.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
    if py_utils.IsEagerMode():
      raise ValueError('InferenceGraph exporter does not work in Eager mode.')
    assert issubclass(model_cfg.cls, base_model.BaseModel)
    if device_options.dtype_override and device_options.fprop_dtype_override:
      raise ValueError(
          'device_options.dtype_override and device_options.fprop_dtype_override'
          ' cannot both be set.')
    if subgraph_filter and not isinstance(subgraph_filter, (tuple, list)):
      subgraph_filter = [subgraph_filter]

    # Disable assertions unless user explicitly enables it.
    if FLAGS['enable_asserts'].using_default_value:
      FLAGS.enable_asserts = False

    # TODO(laurenzo): Work out how much we need to specify here in terms of
    # cluster configuration.
    cls._SetClusterParams(model_cfg.cluster, device_options)

    # Configure the model.
    model_cfg.random_seed = random_seed
    model_cfg.is_inference = True

    if disable_packed_input:

      def _DisablePackedInput(task):
        if (_ParamExists(task, 'encoder') and
            _ParamExists(task.encoder, 'packed_input')):
          task.encoder.packed_input = False
        if (_ParamExists(task, 'decoder') and
            _ParamExists(task.decoder, 'packed_input')):
          task.decoder.packed_input = False

      if issubclass(model_cfg.cls, base_model.MultiTaskModel):
        for _, task_param in model_cfg.task_params.IterParams():
          _DisablePackedInput(task_param)
      else:
        _DisablePackedInput(model_cfg.task)

    tf.logging.debug('Model %s params:', model_cfg.name)
    for line in model_cfg.ToText().split('\n'):
      tf.logging.debug('%s', line)

    # Instantiate the graph.
    graph = tf.Graph()
    with graph.as_default():
      tf.random.set_seed(random_seed)
      cluster = model_cfg.cluster.Instantiate()
      device = cluster.GetPlacer()
      tpu_const_scope = _DummyScope()
      if (IsTpu(device_options) and
          device_options.var_options == 'AS_CONSTANTS'):
        # Do not specify devices for variables if we are marking them as
        # constants.
        device = ''
        tpu_const_scope = ConstGuaranteeScope()

      with cluster, tf.device(device), tpu_const_scope:

        bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
            device_options)

        if bfloat16_override:
          py_utils.UpdateDtype(model_cfg, tf.bfloat16)
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        act_bfloat16_override = ShouldForceBfloat16ForActivations(
            device_options)
        if act_bfloat16_override:
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        # Hard-code TPU-related flags prior to instantiating model.
        old_enable_asserts = FLAGS.enable_asserts
        old_xla_device = FLAGS.xla_device
        if IsTpu(device_options):
          FLAGS.enable_asserts = False
          FLAGS.xla_device = 'tpu'

        try:
          mdl = model_cfg.Instantiate()
          task = mdl.GetTask(model_task_name)

          variables_to_restore = (
              _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
              mdl.ema.variables_to_restore(mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables
                .get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
            # For TPU embedding layers, if the table explicitly specifies the
            # inference dtype as bfloat16, the variables in the checkpoint must
            # already be in bfloat16, so we change back to bfloat16 to avoid
            # dtype mismatch.
            for var_name in (tpu_embedding_layers.TpuEmbeddingCollection.Get()
                             .inference_with_bfloat16_var_names):
              saver_var_spec[var_name] = variables_to_restore[var_name]
          else:
            saver_var_spec = variables_to_restore

          saver = tf.train.Saver(saver_var_spec)
          tf.variables_initializer(
              tf.global_variables(), name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          if freeze_checkpoint or freeze_defaults:
            # Replace variables with tensors using tf.identity in theta before
            # freezing to avoid the graph referencing types of DT_RESOURCE.
            def AddIdentityToTheta(layer):
              # pylint: disable=protected-access
              layer._private_theta = py_utils.Transform(tf.identity,
                                                        layer._private_theta)
              # pylint: enable=protected-access
              layer.children.Transform(AddIdentityToTheta)

            AddIdentityToTheta(task)

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          if not inference_graph_proto.subgraphs and subgraph_filter:
            raise ValueError(
                f'Subgraph filters {subgraph_filter} filtered out all '
                'subgraphs. Defined subgraphs: '
                f'{list(subgraphs_proto.subgraphs.keys())}')

          # Yes, graph collections are bad, but this seems to be the easiest
          # way to get these assets registered from TextFileInitializer.
          assets_collection = tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
          for asset in assets_collection:
            if asset.op.type == 'Const' and asset.op.get_attr(
                'dtype') == tf.dtypes.string:
              constant_value = asset.op.get_attr('value')
              if constant_value.string_val:
                tf.logging.info('Found asset file_path: %s',
                                constant_value.string_val[0])
                asset_file_def = inference_graph_proto.asset_file_def.add()
                asset_file_def.tensor_info.name = asset.name
                asset_file_def.filename = constant_value.string_val[0]

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

    tf.logging.info('Graph contains ops: %r',
                    [op.name for op in graph.get_operations()])

    # Collection defs
    if not tf.executing_eagerly():
      if export_graph_collections:
        meta_graph = tf.train.export_meta_graph(graph=graph)
        for key in meta_graph.collection_def:
          tf.logging.info('copying collection %s', key)
          inference_graph_proto.collection_def[key].CopyFrom(
              meta_graph.collection_def[key])
    else:
      tf.logging.warning('Not exporting collection defs '
                         'since operating in eager mode.')

    # Freezing.
    if freeze_defaults or freeze_checkpoint:
      output_op_names = GetOutputOpNames(
          graph,
          inference_graph_proto,
          preserve_colocation_nodes=False,
          preserve_saver_restore_nodes=False)
      if cls._DeviceSupportsFreezing(device_options):
        raise ValueError('freeze_checkpoint cannot be used with device ' +
                         device_options.device)
      if freeze_checkpoint:
        tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
        graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                               output_op_names)
      elif freeze_defaults:
        tf.logging.info('Default initializing graph and freezing.')
        graph_def = _FreezeDefaults(graph, output_op_names)
    else:
      inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())
      graph_def = graph.as_graph_def()

      if prune_graph:
        output_op_names = GetOutputOpNames(graph, inference_graph_proto)

        # Prune the graph to just the parts we need.
        # To support restoring, we have to not prune out the restore node.
        output_op_names.append('init_all_tables')
        output_op_names.append('init_all_variables')
        output_op_names.append('save/control_dependency')
        output_op_names.append('save/restore_all')
        if IsTpu(device_options) and device_options.gen_init_op:
          output_op_names.append('tpu_init_op')

        tf.logging.info('Pruning graph to output ops: %r', output_op_names)
        graph_def = tf.compat.v1.graph_util.extract_sub_graph(
            graph_def, output_op_names)

    if not device_options.retain_device_placement:
      # Clear the device so that the runtime can choose.
      tf.logging.info('Clearing device placement for: %s',
                      device_options.device)
      for node in graph_def.node:
        node.ClearField('device')
      for function in graph_def.library.function:
        for node_def in function.node_def:
          node_def.ClearField('device')

    inference_graph_proto.graph_def.CopyFrom(graph_def)

    if export_path:
      with tf.io.gfile.GFile(export_path, 'w') as f:
        f.write(text_format.MessageToString(inference_graph_proto))
    return inference_graph_proto
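
When `export_path` is set, the proto is written in protobuf text format, so it can be read back with the standard protobuf tools. A sketch (the path is a placeholder; assumes lingvo's generated inference_graph_pb2 module):

    from google.protobuf import text_format
    from lingvo import compat as tf
    from lingvo.core import inference_graph_pb2

    with tf.io.gfile.GFile('/tmp/inference.pbtxt', 'r') as f:
        proto = text_format.Parse(f.read(),
                                  inference_graph_pb2.InferenceGraph())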