def testEagerMultiLearnerCheckpointCompatibility(self):
    self.assertTrue(tf.executing_eagerly())
    cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
    mdl = cfg.Instantiate()
    with py_utils.GradientTape(persistent=True):
      mdl.ConstructFPropBPropGraph()

    eager_v1_logdir = os.path.join(self.get_temp_dir(), 'eager_v1')
    eager_v2_logdir = os.path.join(self.get_temp_dir(), 'eager_v2')
    checkpointer.EagerCheckpointerV1(eager_v1_logdir, mdl).Save(gsteps=0)
    checkpointer.EagerCheckpointerV2(eager_v2_logdir, mdl).Save(gsteps=0)
    eager_v1_keys = _GetCheckpointKeys(
        os.path.join(eager_v1_logdir, 'ckpt_V1', 'ckpt-00000000'))
    eager_v2_keys = _GetCheckpointKeys(
        os.path.join(eager_v2_logdir, 'ckpt_V2', 'ckpt-0'))
    # Expecting two more variables in V2 checkpoints:
    # _CHECKPOINTABLE_OBJECT_GRAPH
    # save_counter
    self.assertEqual(len(eager_v1_keys) + 2, len(eager_v2_keys))  # pylint:disable=g-generic-assert

    py_utils.SetEagerMode(False)
    self.assertFalse(tf.executing_eagerly())
    graph_logdir = os.path.join(self.get_temp_dir(), 'graph')
    os.mkdir(graph_logdir)
    with self.session(graph=tf.Graph()) as sess:
      mdl = cfg.Instantiate()
      for lrn in mdl.GetTask().learners:
        lrn.optimizer.params.clear_variable_scope = False
      mdl.ConstructFPropBPropGraph()
      sess.run(tf.global_variables_initializer())
      checkpointer.Checkpointer(graph_logdir, mdl).Save(sess)
    graph_keys = _GetCheckpointKeys(os.path.join(graph_logdir, 'ckpt'))
    self.assertEqual(eager_v1_keys, graph_keys)
Example #2
def GetTensorName(tensor, name_eager=None, i_eager=None):
    """Returns tensor name.

  It is useful for compatibility with eager mode.

  Args:
    tensor: tensor
    name_eager: additional string to append in eager mode
    i_eager: additional index to append in eager mode

  Returns:
    tensor.name in session mode, or concatenation of name_eager, i_eager
      in eager mode
  """
    if not tf.executing_eagerly():
        tensor_name = tensor.name
    else:
        if name_eager and i_eager:
            tensor_name = f'[eager]_{name_eager}_{i_eager}'
        elif name_eager:
            tensor_name = f'[eager]_{name_eager}'
        elif i_eager:
            tensor_name = f'[eager]_{i_eager}'
        else:
            tensor_name = '[eager]'
    return tensor_name
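A minimal usage sketch, not from the Lingvo sources: it assumes GetTensorName above is in scope and that tf resolves to the same lingvo.compat module the helper uses. In graph mode the helper simply returns the tensor's graph name; in eager mode, where tensor names are not meaningful, it falls back to a synthetic '[eager]' label built from the optional arguments.

import lingvo.compat as tf

# Graph mode: the tensor's graph name is returned unchanged.
with tf.Graph().as_default():
  x = tf.constant(1.0, name='x')
  print(GetTensorName(x))  # -> 'x:0'

# Eager mode (sketch): EagerTensor names are meaningless, so the helper
# builds a label from the optional arguments instead, e.g.
#   GetTensorName(eager_tensor, name_eager='loss', i_eager=3)
#   -> '[eager]_loss_3'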
Example #3

    def testEagerEMACheckpointCompatibility(self):
        self.assertTrue(tf.executing_eagerly())
        cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
        # Use non-zero learning rate so that the weights are updated
        cfg.task.train.learner[0].learning_rate = 0.1
        cfg.task.train.learner[1].learning_rate = 0.1

        eager_v1_logdir = os.path.join(self.get_temp_dir(), 'eager_v1')
        eager_v2_logdir = os.path.join(self.get_temp_dir(), 'eager_v2')
        mdl = cfg.Instantiate()

        @tf.function
        def _Update():
            with py_utils.GradientTape(persistent=True):
                mdl.ConstructFPropBPropGraph()

        # Step 1
        _Update()
        # Save V1 checkpoints at step 1.
        ckpt_v1 = checkpointer.EagerCheckpointerV1(eager_v1_logdir, mdl)
        ckpt_v1.Save(gsteps=1)

        ema = mdl.ema
        model_to_ema_map = _GetModelEMAVariablePairs(mdl, ema)
        model_to_ema_map_snapshot_step1 = {
            k: v.value()
            for k, v in model_to_ema_map.items()
        }

        # Step 2
        _Update()
        # Save V2 checkpoints at step 2.
        ckpt_v2 = checkpointer.EagerCheckpointerV2(eager_v2_logdir, mdl)
        ckpt_v2.Save(gsteps=2)

        model_to_ema_map = _GetModelEMAVariablePairs(mdl, ema)
        model_to_ema_map_snapshot_step2 = {
            k: v.value()
            for k, v in model_to_ema_map.items()
        }

        with cluster_factory.SetEval(True):
            # Restores variables to values saved in `eager_v1_logdir`
            ckpt_v1.Restore()
        # Verify that the EMA variables from V1 checkpoints at step 1 successfully
        # overwrite the model variables.
        for v in mdl.variables:
            if v.ref() in model_to_ema_map_snapshot_step1:
                self.assertAllEqual(v,
                                    model_to_ema_map_snapshot_step1[v.ref()])

        with cluster_factory.SetEval(True):
            # Restores variables to values saved in `eager_v2_logdir`
            ckpt_v2.Restore()
        # Verify that the EMA variables from V2 checkpoints at step 2 successfully
        # overwrite the model variables.
        for v in mdl.variables:
            if v.ref() in model_to_ema_map_snapshot_step2:
                self.assertAllEqual(v,
                                    model_to_ema_map_snapshot_step2[v.ref()])
Example #4
def Reset(self, sess):
  if self._dataset:
    if tf.executing_eagerly():
      self._iterator = {key: iter(ds) for key, ds in self._dataset.items()}
    else:
      sess.run([it.initializer for it in self._iterator.values()])
Example #5
    def _InitIterator(self):
        if self.host_id in self._dataset:
            return

        with py_utils.GlobalStepContext(None):
            # Hide global_step tensor from being captured by dataset function.
            ds = self.GetDataset()
        ds.options().experimental_deterministic = False
        self._dataset[self.host_id] = ds
        if tf.executing_eagerly():
            it = iter(ds)
        else:
            it = tf.data.make_initializable_iterator(ds)
        self._iterator[self.host_id] = it
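The two snippets above follow the same eager/graph split for tf.data input pipelines. Below is a self-contained sketch of that pattern, with the dataset and variable names assumed rather than taken from Lingvo: eager mode restarts input by building a fresh Python iterator, while graph mode builds an initializable iterator once and reruns its initializer in a session.

import tensorflow.compat.v1 as tf

ds = tf.data.Dataset.range(3)
if tf.executing_eagerly():
  it = iter(ds)                       # restart = make a new Python iterator
  print([int(x) for x in it])         # [0, 1, 2]
else:
  it = tf.data.make_initializable_iterator(ds)
  nxt = it.get_next()
  with tf.Session() as sess:
    sess.run(it.initializer)          # restart = rerun the initializer
    print(sess.run(nxt))              # 0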
Example #6
  def AddEvalMetric(self, name, value, weight, raise_if_already_added=True):
    """Adds a metric to the eval metrics.

    Args:
      name: A python string. The name of the metric.
      value: A scalar Tensor.
      weight: A scalar Tensor.
      raise_if_already_added: If the metric already exists, raise a ValueError.

    Raises:
      ValueError: if `name` is already defined.

    """
    if name in self._eval_metrics and not tf.executing_eagerly():
      if raise_if_already_added:
        raise ValueError('Metric %s has already been defined.' % name)
    self._eval_metrics[name] = (value, weight)
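A hedged call-site sketch, kept as comments; the metric names and tensors below are assumptions, not from the sources. Each metric is registered as a (value, weight) pair, and in graph mode re-registering an existing name raises unless raise_if_already_added=False.

# Inside a task's eval/decode code, with avg_loss, accuracy and num_examples
# being scalar tensors computed earlier (hypothetical names):
#   self.AddEvalMetric('loss', avg_loss, num_examples)
#   self.AddEvalMetric('accuracy', accuracy, num_examples)
# Overwriting an existing entry without raising:
#   self.AddEvalMetric('loss', avg_loss, num_examples,
#                      raise_if_already_added=False)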
Example #7
def PlotSequenceFeatures(plots, name, **kwargs):
    """Plots a stack of sequence features.

  Args:
    plots: A list of tuple (tensor, seq_len), as returned by
      PrepareSequenceForPlot().
    name: A string for the caption of the plot.
    **kwargs: Keyword arguments passed to AddSubplot().
  """
    if not _ShouldAddSummary():
        return

    with plot.MatplotlibFigureSummary(name,
                                      figsize=(8, len(plots) * 3.5)) as fig:
        for i, (tensor, seq_len) in enumerate(plots):
            if not tf.executing_eagerly():
                tensor_name = tensor.name
            else:
                tensor_name = f'[eager]_{name}_{i}'

            fig.AddSubplot([tensor, seq_len],
                           TrimPaddingAndPlotSequence,
                           title=tensor_name,
                           **kwargs)
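A hedged call sketch with assumed shapes and caption (nothing below comes from the sources): each plot entry pairs a [batch, time, ...] feature tensor with per-example sequence lengths, typically the output of PrepareSequenceForPlot(); the figure is only emitted when Lingvo's summary context enables summaries for the current job.

import lingvo.compat as tf

features = tf.random.uniform([2, 10, 8])        # [batch, time, dim]
seq_len = tf.constant([10, 7], dtype=tf.int32)  # valid lengths per example

# One subplot per (tensor, seq_len) pair; extra kwargs go to AddSubplot().
PlotSequenceFeatures([(features, seq_len)], 'features/log_mel')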
Example #8
    def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
        If train_cfg is a SingleTaskModelParams, we expect only one entry.
      *args: List args to pass through to BaseRunner.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        if py_utils.IsEagerMode():
            assert tf.executing_eagerly()
            tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')

            # Connect to the TPU runtime.
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
            tf.config.experimental_connect_to_cluster(resolver)

        super().__init__(train_cfg, *args, **kwargs)

        data_parallelism = self._cluster.num_splits_per_client
        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(self._logdir, 'train')

        self._variable_renaming_rules = []

        self._ml_perf = None

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params()
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        self._WriteToLog(
            text_format.MessageToString(train_cfg.ToProto(), as_utf8=True),
            self._checkpoint_dir, 'trainer_params.pbtxt')
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        train_cfg = self.params

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit(job=None):
            """Wait until the model is ready."""
            try:
                if py_utils.IsEagerMode():
                    topology = tf.tpu.experimental.initialize_tpu_system(
                        resolver)
                else:
                    # tpu.initialize_system() is called with None as embedding_config, as
                    # embedding_config is not available yet. Later in _Loop, it is called
                    # with the correct embedding_config. Since it cannot be called twice
                    # in the same graph with different embedding_config, we use a
                    # dummy_graph here.
                    dummy_graph = tf.Graph()
                    with dummy_graph.as_default():
                        tpu_initialize_system_op = tf.tpu.initialize_system(
                            embedding_config=None, job=job)

                    with self._GetSession(graph=dummy_graph) as sess:
                        topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_computation_shape is None:
                    computation_shape = py_utils.ComputationShape(
                        num_devices_per_split, topology)
                else:
                    computation_shape = train_cfg.train.tpu_computation_shape
                    assert num_devices_per_split == np.prod(computation_shape)

                if train_cfg.train.tpu_device_order_mode is None:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism)
                else:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(self.device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        if len(self._cluster.all_worker_names) > 1:
            for worker in self._cluster.all_worker_names:
                _WaitTillInit(worker)
        else:
            _WaitTillInit(None)

        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []
        self._ckpt_programs = []

        self._checkpoint_to_load = None
        with self._cluster:
            # Create the ExponentialMovingAverage singleton shared by all programs, if
            # applicable.
            ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
            for task_string, program_schedule_params in ps_params_dict.items():
                program_schedule_params.logdir = self._logdir
                program_schedule_params.num_splits_per_client = data_parallelism
                program_schedule_params.task_name = task_string
                # If the model was created above, we'll inject it here as a
                # shared_model.
                ps = program_schedule_params.Instantiate(
                    shared_model=shared_model,
                    trial=self._trial,
                    ema=ema,
                    tf_master=self._tf_master)
                self._program_schedule_dict[task_string] = ps
                tf.logging.info('program_schedule_params: %s',
                                program_schedule_params.ToText())
                self._programs += ps.Programs()
                if ps.train_program:
                    self._ckpt_programs.append(ps.train_program)
                else:
                    self._ckpt_programs += ps.Programs()
                if program_schedule_params.ml_perf.benchmark_name is not None:
                    self._ml_perf = program_schedule_params.ml_perf
                if ('checkpoint_to_load' in program_schedule_params
                        and program_schedule_params.checkpoint_to_load):
                    if (self._checkpoint_to_load
                            and (self._checkpoint_to_load !=
                                 program_schedule_params.checkpoint_to_load)):
                        raise ValueError(
                            f'Multiple values found for checkpoint_to_load: '
                            f'{self._checkpoint_to_load}, '
                            f'{program_schedule_params.checkpoint_to_load}.')
                    self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

        tf.logging.info('num_programs: %d', len(self._programs))

        # When running in a vizier trainer, the executor reports infeasible runs
        # in case of errors. The programs report metrics and normal completions.
        for program in self._programs:
            if program._should_report_metrics:
                self._should_report_metrics = True

        with self._cluster, tf.container(
                self._container_id), contextlib.ExitStack() as stack:
            if not py_utils.IsEagerMode():
                stack.enter_context(self._graph.as_default())

                if FLAGS.use_tpu_mirrored_vars:
                    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                        FLAGS.tf_master,
                        job_name=FLAGS.worker_job[len('/job:'):])
                    self._tpu_strategy = tf.distribute.experimental.TPUStrategy(
                        resolver, device_assignment=self.device_assignment)
                    stack.enter_context(self._tpu_strategy.scope())
                    stack.enter_context(
                        tpu_strategy._TPUReplicaContext(self._tpu_strategy))
                else:
                    stack.enter_context(tf.device(self._cluster.GetPlacer()))

            if FLAGS.pdb_on_exception:
                stack.enter_context(pdb_wrapper.catch_post_mortem())
            with py_utils.VariableStore(), py_utils.VariableRenameScope(
                    self._variable_renaming_rules):
                # `BuildTpuSubgraph` has to be called before checkpoint restore, so that
                # the optimizer slot variables are guaranteed to be initialized before
                # they get loaded. Otherwise, the optimizers' slot variables will not
                # be properly loaded when V1 checkpoint is used.
                for program in self._programs:
                    program.BuildTpuSubgraph()
                    py_utils.ClearTpuSummaryTensors()

            if not py_utils.IsEagerMode():
                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
                self._initialize_global_vars = tf.global_variables_initializer()

            checkpointer_models = [
                program.GetModel() for program in self._ckpt_programs
            ]

            if py_utils.IsEagerMode():
                if FLAGS.use_v2_checkpoints_in_eager:
                    self._checkpointer = checkpointer.EagerCheckpointerV2(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
                else:
                    self._checkpointer = checkpointer.EagerCheckpointerV1(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
            else:
                self._checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    models=checkpointer_models,
                    init_op=self._initialize_global_vars,
                    train_params=train_cfg.train,
                    save_only=False)

            for program in self._programs:
                program.SetStatusMessageFn(self._SetStatusMessage)

            tpu_embedding_collection = (
                tpu_embedding_layers.TpuEmbeddingCollection.Get())
            self._load_ops = tpu_embedding_collection.load_ops
            self._retrieve_ops = tpu_embedding_collection.retrieve_ops
            self._tpu_embedding = tpu_embedding_collection.tpu_embedding
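A hedged construction sketch, kept as comments because actually instantiating ExecutorTpu requires a reachable TPU worker; the task key, logdir, and the argument names forwarded to BaseRunner are assumptions. For a SingleTaskModelParams config, ps_params_dict must contain exactly one entry keyed by the task name.

# train_cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
# ps_params_dict = {'linear_model': program_schedule_params}  # one entry
# runner = ExecutorTpu(train_cfg, ps_params_dict,
#                      model_task_name='',
#                      logdir='/tmp/executor_train',
#                      tf_master=FLAGS.tf_master)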
Example #9
def Initialize(self, sess):
  if not tf.executing_eagerly():
    self.Reset(sess)
  super().Initialize(sess)
Example #10
def AddPerExampleTensor(self, name, value):
  if name in self._per_example and not tf.executing_eagerly():
    raise ValueError('Metric %s has already been defined.' % name)
  self._per_example[name] = value
Example #11
File: cluster.py  Project: vcj-huy/lingvo
def InitDevicesEager(self):
  assert tf.executing_eagerly()
  self._session_devices = [d.name for d in tf.config.list_logical_devices()]
  tf.logging.info('InitDevices %s' % sorted(self._session_devices))
Example #12
import lingvo.compat as tf
from lingvo.core import cluster_factory
from lingvo.core import py_utils
import numpy as np

FLAGS = tf.flags.FLAGS

# Enable tf.function when eager execution is on-by-default, which is the case
# when:
# - the test target doesn't depend on the disable_tf2 target, and
# - --define=tf_api_version=1 is not specified during the build.
#
# TODO(laigd): remove TF version check when 312743821 and 313682500 are in the
# release.
if tf.executing_eagerly() and tf.compat.v1.__version__ >= '2.3.0':
  try:
    FLAGS.if_use_tf_function = True
    FLAGS.while_loop_use_tf_function = True
    FLAGS.call_defun_use_tf_function = True
  except tf.flags.UnrecognizedFlagError:
    pass

# Disable eager execution for all tests.
tf.disable_eager_execution()

tf.flags.DEFINE_boolean(
    'update_goldens', False,
    'Update the goldens, rather than diffing against them.')

Example #13
def testSomeTFSymbols(self):
  self.assertFalse(tf.executing_eagerly())
  self.assertIsNotNone(tf.logging)
  self.assertIsNotNone(tf.flags)
  self.assertIs(tf.Defun, function.Defun)
Example #14

  def Export(cls,
             model_cfg,
             model_task_name=None,
             device_options=InferenceDeviceOptions(
                 device='',
                 retain_device_placement=False,
                 var_options=None,
                 gen_init_op=True,
                 dtype_override=None,
                 fprop_dtype_override=None),
             freeze_checkpoint=None,
             freeze_defaults=False,
             export_path=None,
             subgraph_filter=None,
             random_seed=None,
             disable_packed_input=True,
             prune_graph=True,
             export_graph_collections=False):
    """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

    Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
    and multi-core inference on TPUs work properly.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default initializes the graph and freeze. Useful for
        early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A string or a list of subgraph names. If not None or
        empty, export only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input for inference writing purposes.
      prune_graph: If true, prune the graph to just the parts we need.
      export_graph_collections: If true, export graph collections to the
        InferenceGraph proto.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
    if py_utils.IsEagerMode():
      raise ValueError('InferenceGraph exporter does not work in Eager mode.')
    assert issubclass(model_cfg.cls, base_model.BaseModel)
    if device_options.dtype_override and device_options.fprop_dtype_override:
      raise ValueError(
          'device_options{dtype_override, fprop_dtype_override} cannot both '
          'be set.')
    if subgraph_filter and not isinstance(subgraph_filter, (tuple, list)):
      subgraph_filter = [subgraph_filter]

    # Disable assertions unless user explicitly enables it.
    if FLAGS['enable_asserts'].using_default_value:
      FLAGS.enable_asserts = False

    # TODO(laurenzo): Work out how much we need to specify here in terms of
    # cluster configuration.
    cls._SetClusterParams(model_cfg.cluster, device_options)

    # Configure the model.
    model_cfg.random_seed = random_seed
    model_cfg.is_inference = True

    if disable_packed_input:

      def _DisablePackedInput(task):
        if (_ParamExists(task, 'encoder') and
            _ParamExists(task.encoder, 'packed_input')):
          task.encoder.packed_input = False
        if (_ParamExists(task, 'decoder') and
            _ParamExists(task.decoder, 'packed_input')):
          task.decoder.packed_input = False

      if issubclass(model_cfg.cls, base_model.MultiTaskModel):
        for _, task_param in model_cfg.task_params.IterParams():
          _DisablePackedInput(task_param)
      else:
        _DisablePackedInput(model_cfg.task)

    tf.logging.debug('Model %s params:', model_cfg.name)
    for line in model_cfg.ToText().split('\n'):
      tf.logging.debug('%s', line)

    # Instantiate the graph.
    graph = tf.Graph()
    with graph.as_default():
      tf.random.set_seed(random_seed)
      cluster = model_cfg.cluster.Instantiate()
      device = cluster.GetPlacer()
      tpu_const_scope = _DummyScope()
      if (IsTpu(device_options) and
          device_options.var_options == 'AS_CONSTANTS'):
        # Do not specify devices for variables if we are marking them as
        # constants.
        device = ''
        tpu_const_scope = ConstGuaranteeScope()

      with cluster, tf.device(device), tpu_const_scope:

        bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
            device_options)

        if bfloat16_override:
          py_utils.UpdateDtype(model_cfg, tf.bfloat16)
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        act_bfloat16_override = ShouldForceBfloat16ForActivations(
            device_options)
        if act_bfloat16_override:
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        # Hard-code TPU-related flags prior to instantiating model.
        old_enable_asserts = FLAGS.enable_asserts
        old_xla_device = FLAGS.xla_device
        if IsTpu(device_options):
          FLAGS.enable_asserts = False
          FLAGS.xla_device = 'tpu'

        try:
          mdl = model_cfg.Instantiate()
          task = mdl.GetTask(model_task_name)

          variables_to_restore = (
              _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
              mdl.ema.variables_to_restore(mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables
                .get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
            # For TPU embedding layers, if the table explicitly specifies the
            # inference dtype as bfloat16, the variables in the checkpoint must
            # already be in bfloat16, so we change back to bfloat16 to avoid
            # dtype mismatch.
            for var_name in (tpu_embedding_layers.TpuEmbeddingCollection.Get()
                             .inference_with_bfloat16_var_names):
              saver_var_spec[var_name] = variables_to_restore[var_name]
          else:
            saver_var_spec = variables_to_restore

          saver = tf.train.Saver(saver_var_spec)
          tf.variables_initializer(
              tf.global_variables(), name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          if freeze_checkpoint or freeze_defaults:
            # Replace variables with tensors using tf.identity in theta before
            # freezing to avoid the graph referencing types of DT_RESOURCE.
            def AddIdentityToTheta(layer):
              # pylint: disable=protected-access
              layer._private_theta = py_utils.Transform(tf.identity,
                                                        layer._private_theta)
              # pylint: enable=protected-access
              layer.children.Transform(AddIdentityToTheta)

            AddIdentityToTheta(task)

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          if not inference_graph_proto.subgraphs and subgraph_filter:
            raise ValueError(
                f'Subgraph filters {subgraph_filter} filtered out all '
                'subgraphs. Defined subgraphs: '
                f'{list(subgraphs_proto.subgraphs.keys())}')

          # Yes, graph collections are bad; however, this seems to be the
          # easiest way to get these assets registered from
          # TextFileInitializer.
          assets_collection = tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
          for asset in assets_collection:
            if asset.op.type == 'Const' and asset.op.get_attr(
                'dtype') == tf.dtypes.string:
              constant_value = asset.op.get_attr('value')
              if constant_value.string_val:
                tf.logging.info('Found asset file_path: %s',
                                constant_value.string_val[0])
                asset_file_def = inference_graph_proto.asset_file_def.add()
                asset_file_def.tensor_info.name = asset.name
                asset_file_def.filename = constant_value.string_val[0]

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

    tf.logging.info('Graph contains ops: %r',
                    [op.name for op in graph.get_operations()])

    # Collection defs
    if not tf.executing_eagerly():
      if export_graph_collections:
        meta_graph = tf.train.export_meta_graph(graph=graph)
        for key in meta_graph.collection_def:
          tf.logging.info('copying collection %s', key)
          inference_graph_proto.collection_def[key].CopyFrom(
              meta_graph.collection_def[key])
    else:
      tf.logging.warning('Not exporting collection defs '
                         'since operating in eager mode.')

    # Freezing.
    if freeze_defaults or freeze_checkpoint:
      output_op_names = GetOutputOpNames(
          graph,
          inference_graph_proto,
          preserve_colocation_nodes=False,
          preserve_saver_restore_nodes=False)
      if cls._DeviceSupportsFreezing(device_options):
        raise ValueError('freeze_checkpoint cannot be used with device ' +
                         device_options.device)
      if freeze_checkpoint:
        tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
        graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                               output_op_names)
      elif freeze_defaults:
        tf.logging.info('Default initializing graph and freezing.')
        graph_def = _FreezeDefaults(graph, output_op_names)
    else:
      inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())
      graph_def = graph.as_graph_def()

      if prune_graph:
        output_op_names = GetOutputOpNames(graph, inference_graph_proto)

        # Prune the graph to just the parts we need.
        # To support restoring, we have to not prune out the restore node.
        output_op_names.append('init_all_tables')
        output_op_names.append('init_all_variables')
        output_op_names.append('save/control_dependency')
        output_op_names.append('save/restore_all')
        if IsTpu(device_options) and device_options.gen_init_op:
          output_op_names.append('tpu_init_op')

        tf.logging.info('Pruning graph to output ops: %r', output_op_names)
        graph_def = tf.compat.v1.graph_util.extract_sub_graph(
            graph_def, output_op_names)

    if not device_options.retain_device_placement:
      # Clear the device so that the runtime can choose.
      tf.logging.info('Clearing device placement for: %s',
                      device_options.device)
      for node in graph_def.node:
        node.ClearField('device')
      for function in graph_def.library.function:
        for node_def in function.node_def:
          node_def.ClearField('device')

    inference_graph_proto.graph_def.CopyFrom(graph_def)

    if export_path:
      with tf.io.gfile.GFile(export_path, 'w') as f:
        f.write(text_format.MessageToString(inference_graph_proto))
    return inference_graph_proto
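A hedged invocation sketch, kept as comments; the exporter class name, checkpoint path, and subgraph name are assumptions rather than details taken from this excerpt. Export() must be called in graph mode (it raises in eager mode) and returns an InferenceGraph proto, optionally also writing it as text to export_path.

# graph_proto = InferenceGraphExporter.Export(
#     model_registry.GetParams('test.LinearModelParams', 'Test'),
#     freeze_checkpoint='/path/to/train/ckpt-00001000',
#     export_path='/tmp/inference.pbtxt',
#     subgraph_filter=['default'])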
Example #15
def AddAttentionSummaryBatchMajor(name,
                                  attention_tensors,
                                  src_paddings,
                                  tgt_paddings,
                                  transcripts=None,
                                  max_outputs=3):
    """Adds an image summary showing the attention probability matrix and state.

  As opposed to AddAttentionSummary(), this takes all tensors with the batch
  dimension in axis 0.

  Args:
    name: Summary name.
    attention_tensors: A list of 3D tensors shaped [batch_size, target_len,
      source_len] where attention[b, i, j] is the probability for the i-th
      output attending to the j-th input for element b in the batch.
    src_paddings: A tensor of binary paddings shaped [batch, source_len] for the
      source sequence. Or a list of tensors of the same length as
      attention_tensors with separate paddings for each entry in
      attention_tensors.
    tgt_paddings: A tensor of binary paddings shaped [batch, target_len] for the
      target sequence. Or a list of tensors of the same length as
      attention_tensors with separate paddings for each entry in
      attention_tensors.
    transcripts: Optional, transcripts shaped [batch, source_len] for the source
      sequence.
    max_outputs: Integer maximum number of elements of the batch to plot.
  """
    def VerifyLen(paddings):
        length = len(paddings) if isinstance(paddings, list) else 1
        if length != 1 and length != len(attention_tensors):
            raise ValueError('Bad length of paddings list {}'.format(length))

    VerifyLen(src_paddings)
    VerifyLen(tgt_paddings)

    # Verify shapes.
    for i, attention_tensor in enumerate(attention_tensors):
        src, tgt = src_paddings, tgt_paddings
        src = src[0 if len(src) == 1 else i] if isinstance(src, list) else src
        tgt = tgt[0 if len(tgt) == 1 else i] if isinstance(tgt, list) else tgt
        tgt_shape = py_utils.GetShape(tgt)

        if not tf.executing_eagerly():
            attention_tensor_name = attention_tensor.name
        else:
            attention_tensor_name = f'[eager]_{name}_{i}'

        attention_tensors[i] = tf.identity(
            py_utils.with_dependencies([
                py_utils.assert_equal(
                    py_utils.GetShape(attention_tensor), tgt_shape[:2] +
                    [py_utils.GetShape(src)[1]] + tgt_shape[2:])
            ], attention_tensor), re.sub(':.*$', '', attention_tensor_name))

    if not _ShouldAddSummary():
        return

    def ToLengths(paddings):
        paddings = paddings if isinstance(paddings, list) else [paddings]
        return [SequenceLength(p) for p in paddings]

    def Get(lengths, i):
        return lengths[0 if len(lengths) == 1 else i]

    src_lens = ToLengths(src_paddings)
    tgt_lens = ToLengths(tgt_paddings)

    with plot.MatplotlibFigureSummary(name + '/Attention',
                                      max_outputs=max_outputs,
                                      gridspec_kwargs={'hspace': 0.3}) as fig:
        for n, atten in enumerate(attention_tensors):
            # Diagnostic metric that decreases as attention picks up.
            max_entropy = tf.math.log(tf.cast(Get(src_lens, n), tf.float32))
            max_entropy = tf.expand_dims(tf.expand_dims(max_entropy, -1), -1)
            atten_normalized_entropy = -atten * tf.math.log(
                atten + 1e-10) / max_entropy
            scalar(name + '/Attention/average_normalized_entropy/%d' % n,
                   tf.reduce_mean(atten_normalized_entropy))
            args = [atten, Get(src_lens, n), Get(tgt_lens, n)]
            if transcripts is not None and n == 0:
                args.append(transcripts)

            if not tf.executing_eagerly():
                atten_name = atten.name
            else:
                atten_name = f'[eager]_{name}_{n}'

            fig.AddSubplot(args,
                           TrimPaddingAndPlotAttention,
                           title=atten_name,
                           xlabel='Input',
                           ylabel='Output')
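A hedged call-shape sketch with assumed shapes (not from the sources): a single attention tensor of shape [batch, target_len, source_len] with matching binary paddings; as with the other plotting helpers, the image summary is only written when summaries are enabled for the current job.

import lingvo.compat as tf

atten = tf.random.uniform([2, 5, 7])   # [batch, target_len, source_len]
src_paddings = tf.zeros([2, 7])        # [batch, source_len]
tgt_paddings = tf.zeros([2, 5])        # [batch, target_len]

AddAttentionSummaryBatchMajor(
    'enc_dec_atten', [atten], src_paddings, tgt_paddings, max_outputs=2)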