コード例 #1
0
 def __init__(self, params):
   """Creates the `before_tpl` children followed by the `cell_tpl` children.

   On TPU (`py_utils.use_tpu()`), the `before_tpl` layers are built under
   model-split device 0 and cell `k` under model-split device `k`. Otherwise
   the empty device string is used for all of them.

   Args:
     params: Params for this layer. `params.name` must be set, and every
       template in `params.before_tpl` and `params.cell_tpl` must be named.
   """
   super(SeqLayer, self).__init__(params)
   p = self.params
   assert p.name
   cell_count = len(p.cell_tpl)
   self._before_layers = []
   self._cells = []

   if py_utils.use_tpu():
     placer = self.cluster
     pre_device = placer.WorkerDeviceInModelSplit(0)
     per_cell_device = [
         placer.WorkerDeviceInModelSplit(idx) for idx in range(cell_count)
     ]
   else:
     pre_device = ''
     per_cell_device = [''] * cell_count

   def _AddChild(tpl, device, registry):
     # Build the child inside the device scope so ops created while
     # constructing it get that device.
     with tf.device(device):
       assert tpl.name
       self.CreateChild(tpl.name, tpl)
       registry.append((tpl.name, self.children[tpl.name]))

   for tpl in p.before_tpl:
     _AddChild(tpl, pre_device, self._before_layers)
   for tpl, device in zip(p.cell_tpl, per_cell_device):
     _AddChild(tpl, device, self._cells)
コード例 #2
0
    def FProp(self, theta, *args):
        """Runs the sub layer data-parallel over the devices of one split.

        Every input is split along axis 0 into one shard per device of the
        model split. The sub layer runs on each shard under that shard's
        device, and the per-shard outputs are concatenated along axis 0.

        Args:
          theta: A NestedMap object containing weights' values of this layer
            and its children layers.
          *args: One or more tf.Tensors that share the same first (batch)
            dimension.

        Returns:
          The sub layer's output. If the concatenated result is a list (the
          sub layer returned tuples), it is returned as a tuple; otherwise it
          is returned unchanged.
        """
        p = self.params
        with tf.name_scope(p.name):
            assert all(isinstance(x, tf.Tensor) for x in args)
            cluster = self.cluster
            n_dev = cluster.num_devices_per_split
            if n_dev == 1:
                # A single device: no splitting needed.
                return self.sub.FProp(theta.sub, *args)

            shards = py_utils.SplitRecursively(list(args), n_dev, axis=0)
            shard_outputs = []
            for idx, shard_inputs in enumerate(shards):
                dev = cluster.WorkerDeviceInModelSplit(idx)
                tf.logging.info('%d on device %s', idx, dev)
                with tf.device(dev):
                    result = self.sub.FProp(theta.sub, *shard_inputs)
                    # Tuple outputs are converted to lists here and back to a
                    # tuple after concatenation; a single tensor is kept as-is.
                    shard_outputs.append(
                        list(result) if isinstance(result, tuple) else result)
            merged = py_utils.ConcatRecursively(shard_outputs, axis=0)
            return tuple(merged) if isinstance(merged, list) else merged
コード例 #3
0
  def FProp(self, theta, *args):
    """Applies the `before_tpl` layers, then each cell on its own device.

    The `before_tpl` layers run without an explicit device scope here. Cell `i`
    of `cell_tpl` runs under `cluster.WorkerDeviceInModelSplit(i)`. Each
    layer's output, normalized with `_ToTuple`, becomes the positional args of
    the next layer.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Input args fed to the first layer.

    Returns:
      The output of the last layer applied, or `args` unchanged when there are
      no layers at all.
    """
    cluster = self.cluster
    values = args

    for layer_name, layer in self._before_layers:
      values = layer.FProp(theta[layer_name], *_ToTuple(values))

    for idx in range(len(self.params.cell_tpl)):
      with tf.device(cluster.WorkerDeviceInModelSplit(idx)):
        cell_name, cell = self._cells[idx]
        values = cell.FProp(theta[cell_name], *_ToTuple(values))

    return values
コード例 #4
0
    def CreateTpuEmbeddingEnqueueOps(self):
        """Creates the TPU embedding enqueue ops on the infeed host(s).

        Must be called after the monolithic TPUEmbeddingLayer has been
        instantiated: the TPU embedding object is read from the
        `py_utils.TPU_EMBEDDING` graph collection. When that collection is
        empty, this only logs and returns.

        For every infeed host, each embedding feature of `self._batch` is
        split across the host's cores with `tf.split` and converted from dense
        ids to sparse form (-1 is treated as padding). The resulting enqueue
        data is turned into enqueue ops. All ops are grouped and appended to
        `self._tpu_infeed_op`.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = collection[0] if collection else None

        enqueue_ops = []

        # Multi-host topologies with a TPU embedding need per-host infeed.
        multi_host_embedding = num_tpu_hosts > 1 and tpu_embedding is not None
        if multi_host_embedding and not p.use_per_host_infeed:
            tf.logging.fatal(
                'TPU Embedding must be used with per_host_infeed with multiple '
                'TPU host topologies.')

        if tpu_embedding is None:
            tpu_emb_input_keys = []
        else:
            tpu_emb_input_keys = list(
                tpu_embedding.feature_to_config_dict.keys())
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        if not tpu_embedding:
            return

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on
                    # TPU. With MultiTaskData, bucket_keys sits at the second
                    # level of the dictionary, hence endswith().
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                cores = tpu_embedding.num_cores_per_host
                per_core = [{} for _ in range(cores)]
                for key in tpu_emb_input_keys:
                    shards = tf.split(self._batch[key], cores)
                    for core_idx, shard in enumerate(shards):
                        # Dense -> sparse; -1 is assumed to be the padding id.
                        sample_indices = tf.where(tf.not_equal(shard, -1))
                        embedding_indices = tf.gather_nd(shard, sample_indices)
                        per_core[core_idx][key] = tpu_embedding_lib.EnqueueData(
                            embedding_indices, sample_indices)
                enqueue_ops += tpu_embedding.generate_enqueue_ops(per_core)
        self._tpu_infeed_op.append(tf.group(*enqueue_ops))
コード例 #5
0
    def TpuDequeueBatch(self):
        """Creates the TPU dequeue op and packs its tensors into the batch.

        Must be called within a TPU context, after CreateTpuEnqueueOps.

        Returns:
          A NestedMap of the input batch (`self._batch.Pack` applied to the
          dequeued tensors).
        """
        assert self._tpu_queues, 'CreateTpuEnqueueOps must be called first.'
        # The on-core dequeue_tuple op only cares about the shapes/types being
        # dequeued, so the first queue is used for all of them.
        first_queue = self._tpu_queues[0]
        with tf.device(tf.tpu.core(0)):
            dequeued = first_queue.generate_dequeue_op()
        return self._batch.Pack(dequeued)
コード例 #6
0
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for each variable and its gradient.

  Args:
    vs_gs: A NestedMap-like object whose `FlattenItems()` yields
      `(name, (var, grad))` pairs. For tf.IndexedSlices gradients only the
      gathered rows of the variable are summarized. Complex values are
      summarized by magnitude (tf.abs).
  """
  for raw_name, (var, grad) in vs_gs.FlattenItems():
    key = py_utils.SanitizeScopeKey(raw_name)
    with tf.device(var.device), tf.name_scope(key + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        # Keep only the variable rows touched by the sparse gradient.
        var, grad = tf.gather(var, grad.indices), grad.values
      if var.dtype.is_complex:
        var, grad = tf.abs(var), tf.abs(grad)

    # The histogram calls are made outside the device/name scope above.
    histogram('var_hist/' + key, var)
    histogram('grad_hist/' + key, grad)
コード例 #7
0
 def Recv(self):
     """Receives a tensor from the channel.

     When `_send_tpu_core` is not -1, an XLA recv op is created under the
     receiving device. Otherwise (presumably a non-TPU sender), a raw Recv op
     is used and its static shape is set to the channel's shape.

     Returns:
       The received tensor.
     """
     if self._send_tpu_core != -1:
         with tf.device(self._recv_device):
             return xla.recv(self._dtype,
                             tensor_name=self._name,
                             shape=self._shape,
                             name="Recv_" + self._name)

     received = tf.raw_ops.Recv(tensor_type=self._dtype,
                                tensor_name=self._name,
                                send_device=self._send_device,
                                send_device_incarnation=0,
                                recv_device=self._recv_device)
     received.set_shape(self._shape)
     return received
コード例 #8
0
 def Send(self, tensor):
     """Sends a tensor through the channel.

     A channel accepts exactly one Send, and the tensor's dtype must equal the
     channel's dtype; both are checked with asserts before any op is built.

     Args:
       tensor: The tensor to send.

     Returns:
       An XLA send op created under the sending device when `_send_tpu_core`
       is not -1, otherwise a raw Send op.
     """
     dtype_matches = tensor.dtype == self._dtype
     assert dtype_matches
     assert not self._send_called, ("Send called multiple times for %s" %
                                    self._name)
     self._send_called = True

     if self._send_tpu_core != -1:
         with tf.device(self._send_device):
             return xla.send(tensor,
                             tensor_name=self._name,
                             name="Send_" + self._name)

     return tf.raw_ops.Send(tensor=tensor,
                            tensor_name=self._name,
                            send_device=self._send_device,
                            send_device_incarnation=0,
                            recv_device=self._recv_device)
コード例 #9
0
  def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
               tf_master, **kwargs):
    """Construct an ExecutorTpu BaseRunner.

    The constructor works in four steps:
      1. Checks the cluster layout; exactly one replica is asserted.
      2. Instantiates one ProgramSchedule per entry of `ps_params_dict`.
      3. Initializes the TPU system and records the device assignment.
      4. Builds, inside `self._graph`, every program's TPU subgraph and
         checkpointer, the table/local-variable initializers, and a
         save-only checkpointer.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
          if train_cfg is a SingleTaskModelParams, we expect only one entry.
      model_task_name: An override for multi-task models, currently unused
          here beyond being forwarded to BaseRunner.
      logdir:  String path to the log directory to output to.
      tf_master: String path to the master job, e.g. 'local'.
      **kwargs: keyword args to pass through to BaseRunner.
    """
    super(ExecutorTpu, self).__init__(train_cfg, model_task_name, logdir,
                                      tf_master, **kwargs)

    self._cluster_def = self._cluster.worker_cluster_def

    # There is a single Executor task
    assert self._cluster.num_replicas == 1
    # Data parallelism is the number of splits driven by this (single) client.
    data_parallelism = self._cluster.num_splits_per_client

    assert data_parallelism
    num_devices_per_split = self._cluster.num_devices_per_split
    tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                    data_parallelism, num_devices_per_split)

    self.task_scheduler = None
    self._checkpoint_dir = os.path.join(logdir, 'train')

    # Only populated for RegExSharedVariableModel (see below).
    self._variable_renaming_rules = []

    # Set from the last program schedule whose ml_perf.benchmark_name is set.
    self._ml_perf = None

    # If this is a multi-task model, grab the params for the TaskScheduler.
    if issubclass(train_cfg.cls, base_model.SingleTaskModel):
      tf.logging.info('single_task_model')
      assert len(ps_params_dict) == 1
      self._model_task_name = list(ps_params_dict.keys())[0]
      self._single_task_mode = True
    elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
      tf.logging.info('multi_task_model')

      if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
        self._variable_renaming_rules = train_cfg.variable_renaming_rules

      # Without an explicit schedule, fall back to a ConstantScheduler over
      # the configured task probabilities.
      if train_cfg.task_schedule is None:
        task_schedule_params = task_scheduler.ConstantScheduler.Params()
        task_schedule_params.task_probs = sorted(
            list(train_cfg.task_probs.IterParams()))
      else:
        task_schedule_params = train_cfg.task_schedule
      self.task_scheduler = task_schedule_params.Instantiate()
      self._single_task_mode = False
    else:
      tf.logging.fatal(
          'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
          train_cfg.cls)

    tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

    self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir, 'params.txt')
    self._program_schedule_dict = {}
    self._programs = []

    # Instantiate one ProgramSchedule per task and collect all programs.
    for task_string, program_schedule_params in ps_params_dict.items():
      program_schedule_params.logdir = logdir
      program_schedule_params.num_splits_per_client = data_parallelism
      program_schedule_params.task_name = task_string
      ps = program_schedule_params.Instantiate()
      self._program_schedule_dict[task_string] = ps
      tf.logging.info('program_schedule_params: %s',
                      program_schedule_params.ToText())
      self._programs += ps.Programs()
      if program_schedule_params.ml_perf.benchmark_name is not None:
        self._ml_perf = program_schedule_params.ml_perf

    tf.logging.info('num_programs: %d', len(self._programs))

    if self._ml_perf is not None:
      self._ml_perf_log = True
      mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
    else:
      self._ml_perf_log = False

    # BaseRunner legacy
    self.enqueue_ops = None

    @py_utils.RetryOnTransientTfError()
    def _WaitTillInit():
      """Initializes the TPU system and records the TPU device assignment.

      Runs tf.tpu.initialize_system in a session on this cluster and builds a
      device assignment with one replica per data-parallel split, each split
      spanning `num_devices_per_split` devices. The assignment is stored via
      py_utils.SetTpuDeviceAssignment. Errors in
      py_utils.transient_tf_errors are logged and re-raised; the decorator
      presumably retries on them.
      """
      try:
        with self._graph.as_default(), self._GetSession(
            cluster_def=self._cluster_def,
            disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
        ) as sess:
          topology = sess.run(
              tf.tpu.initialize_system(embedding_config=None, job=None))
          device_assignment = device_assignment_lib.device_assignment(
              topology,
              computation_shape=py_utils.ComputationShape(
                  num_devices_per_split),
              num_replicas=data_parallelism)
          py_utils.SetTpuDeviceAssignment(device_assignment)
          tf.logging.info('device_assignment.core_assignment: %s',
                          str(device_assignment.core_assignment))
          tf.logging.info('device_assignment.topology.device_coordinates: %s',
                          str(device_assignment.topology.device_coordinates))
      except py_utils.transient_tf_errors as e:
        tf.logging.info('TPU initialization failed: %s', e)
        raise

    if self._ml_perf_log:
      mlp_log.mlperf_print(key='cache_clear', value=True)
      mlp_log.mlperf_print(key='init_start', value=None)
    _WaitTillInit()

    with self._graph.as_default(), tf.container(self._container_id):
      tf.logging.info('self._cluster.job_spec.name: %s',
                      self._cluster.job_spec.name)
      # Ops default to the job's device unless the cluster placer is enabled
      # for the executor via FLAGS.cluster_placer_in_executor.
      with self._cluster, tf.device(
          self._cluster.job_spec.name if not FLAGS.cluster_placer_in_executor
          else self._cluster.GetPlacer()):
        # Variables (global step and program variables) are created under the
        # model's variable renaming rules, if any.
        with py_utils.VariableRenameScope(self._variable_renaming_rules):
          _ = py_utils.GetOrCreateGlobalStepVar()
          for program in self._programs:
            program.BuildTpuSubgraph()
        for program in self._programs:
          program.SetStatusMessageFn(self._SetStatusMessage)
          program.CreateCheckpointer()
        self._initialize_tables = tf.tables_initializer()
        self._initialize_local_vars = tf.local_variables_initializer()

        self.save_only_checkpointer = checkpointer.Checkpointer(
            self._checkpoint_dir,
            model=None,
            train_params=train_cfg.train,
            save_only=True)
コード例 #10
0
 def _DecoderDevice(self):
     """Returns the device scope used for the decoder computation.

     The scope is `tf.device('')`; no explicit device name is requested here.
     """
     unconstrained = ''
     return tf.device(unconstrained)
コード例 #11
0
 def SplitInputBatch(self, num_splits):
     """Delegates to the parent SplitInputBatch under the input device.

     Args:
       num_splits: Number of splits, forwarded unchanged to the parent.

     Returns:
       Whatever the parent class's SplitInputBatch returns.
     """
     input_device = self.cluster.input_device
     parent = super(_UseInputDevice, self)
     with tf.device(input_device):
         return parent.SplitInputBatch(num_splits)
コード例 #12
0
 def __init__(self, params):
     """Runs the parent constructor inside the cluster's input-device scope.

     Args:
       params: Params forwarded unchanged to the parent constructor.
     """
     input_device = self.cluster.input_device
     with tf.device(input_device):
         super(_UseInputDevice, self).__init__(params)
コード例 #13
0
  def FProp(self, theta, *args):
    """Run multiple cells in different devices in a pipelining manner.

    In eval mode (`self.do_eval`), the `before_tpl` layers and cells run
    sequentially without pipelining. Otherwise:
      1. The `before_tpl` layers run on the first cell's device.
      2. Their outputs are split into `p.num_micro_batches` micro-batches along
         `p.batch_dim` and stacked on a new leading axis.
      3. The micro-batches are pushed through the cells with
         recurrent.StackedRecurrent, cell `i` on model-split device `i`.
      4. On the last device, the micro-batches are merged back into the batch
         dimension.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Non-keyworded variable length argument list of input tensors.

    Returns:
      In eval mode, the output of the last layer. Otherwise, if
      `p.nested_map_fprop` is set, a structure mirroring `state_shapes[-1]`;
      if not, a single tensor when the last cell has one output and a tuple of
      tensors otherwise.
    """
    # TODO(huangyp): handle optional None inputs.
    p = self.params
    if self.do_eval:
      # Eval: plain sequential FProp through all layers, no micro-batching.
      outputs = copy.copy(args)
      for (name, l) in self._before_layers + self._cells:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      return outputs

    num_cells = len(p.cell_tpl)
    cluster = self.cluster

    # Compute shapes of input and output tensors.
    # As used below, state_shapes[i] describes the inputs of cell i and
    # state_shapes[i + 1] describes its outputs.
    input_shapes = self._get_input_shapes(*args)
    state_dtype = self._get_state_dtype(*args)
    state_shapes = self._CalculateOutputShapes(input_shapes)
    tf.logging.info('state_shapes={}'.format(state_shapes))

    def GetCellFn(i):
      """Get the ith feature extraction layer."""

      def CellFn(theta, state0, inputs):
        """A cell fn executed inside of StackedRecurrent.

        Runs cell `i` on one micro-batch. The previous state is ignored; the
        new state holds the cell's non-None outputs plus the micro-batch id.
        """
        del state0

        # Restores the static shape (lost through the recurrent loop) of the
        # input named `name` from the precomputed `t_shape`.
        def _FPropInputSetShape(name, t_shape):
          if t_shape is None:
            return None
          inputs[name].set_shape(t_shape.ToTensorShape().as_list())
          return inputs[name]

        if p.nested_map_fprop:
          # pylint: disable=protected-access
          fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
          # pylint: enable=protected-access
        else:
          # Flat mode: inputs are keyed 's0', 's1', ... by position.
          fprop_inputs = []
          for input_idx, input_shape in enumerate(state_shapes[i]):
            name = 's{}'.format(input_idx)
            fprop_inputs.append(_FPropInputSetShape(name, input_shape))

        with py_utils.RemoveAssertContext(remove=True):
          with CellFnFPropOpReplacementWrapper():
            tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
            # The per-micro-batch global step, carried as part of the state.
            mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
            SetOverWriteGlobalStep(mb_tensor)
            _, cell = self._cells[i]
            fprop_inputs = _ToTuple(fprop_inputs)
            outputs = cell.FProp(theta, *fprop_inputs)

        if p.nested_map_fprop:
          assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
          state1 = outputs.Filter(lambda x: x is not None)
        else:
          state1 = py_utils.NestedMap()
          outputs = _ToTuple(outputs)
          assert len(outputs) == len(state_shapes[i + 1])
          for output_idx in range(len(outputs)):
            if outputs[output_idx] is not None:
              name = 's{}'.format(output_idx)
              state1[name] = outputs[output_idx]
        # Pass the micro-batch id on to the next cell.
        state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
        return state1, py_utils.NestedMap()

      return CellFn

    cell_fns = []
    accumulator_layers = []
    thetas = []
    init_states = []
    devices = []
    for cell_idx in range(num_cells):
      cell_name, cell = self._cells[cell_idx]
      accumulator_layers.append(cell)
      cell_fns.append(GetCellFn(cell_idx))
      thetas.append(theta[cell_name])

      def _TfZeros(t_shape):
        if t_shape is None:
          return None
        return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

      # The initial state of each cell is zeros shaped like its outputs.
      if p.nested_map_fprop:
        init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
        init_state = init_state.Filter(lambda x: x is not None)
      else:
        init_state = py_utils.NestedMap()
        for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
          state = _TfZeros(state)
          if state is not None:
            name = 's{}'.format(output_idx)
            init_state[name] = state
      init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
      init_states.append(init_state)

      devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

    # No custom gradients; identity output and output-gradient functions.
    cell_grads = [None] * num_cells
    cell_outs = [lambda x: x] * num_cells
    cell_out_grads = [lambda x: x] * num_cells

    with tf.device(devices[0]):
      previous = _ToTuple(args)
      for (name, l) in self._before_layers:
        previous = l.FProp(theta[name], *previous)
        previous = _ToTuple(previous)

      def _StackAndSplit(x):
        # Split tensors into microbatches.
        # Result: a new leading axis of size p.num_micro_batches.
        if x is None:
          return None
        return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))

      if p.nested_map_fprop:
        inputs = py_utils.Transform(_StackAndSplit, previous[0])
        inputs = inputs.Filter(lambda x: x is not None)
      else:
        inputs = py_utils.NestedMap()
        for output_idx, output_tensor in enumerate(previous):
          output_tensor = _StackAndSplit(output_tensor)
          if output_tensor is not None:
            name = 's{}'.format(output_idx)
            inputs[name] = output_tensor
      # Micro-batch t gets id global_step * num_micro_batches + t.
      gs_tensor = py_utils.GetGlobalStep()
      inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
          tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
          for t in range(p.num_micro_batches)
      ])
    tf.logging.info('pipeline input = {}'.format(inputs))
    output_state, _ = recurrent.StackedRecurrent(
        devices=devices,
        cell_fns=cell_fns,
        cell_grads=cell_grads,
        cell_outs=cell_outs,
        cell_out_grads=cell_out_grads,
        thetas=thetas,
        init_states=init_states,
        inputs=inputs,
        accumulator_layers=accumulator_layers,
        unused_acc_state=True)

    with tf.device(devices[-1]):

      def _ReshapeRetVal(name, t_shape):
        """Restore shape for tensors in microbatches."""
        if t_shape is None:
          return None
        output_tensor = output_state[name]
        if p.batch_dim != 0:
          # Move the leading micro-batch axis to just before batch_dim so the
          # reshape below merges (micro_batch, batch) in split order.
          perm = list(range(1, p.batch_dim + 1)) + [0]
          perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
          output_tensor = tf.transpose(output_tensor, perm=perm)
        output_shape = t_shape.ToTensorShape().as_list()
        output_shape[p.batch_dim] *= p.num_micro_batches
        output_tensor = tf.reshape(output_tensor, output_shape)
        return output_tensor

      # Construct the final return values from output_state.
      if p.nested_map_fprop:
        # pylint: disable=protected-access
        output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
        # pylint: enable=protected-access
      else:
        output_tensors = []
        for output_idx, state_shape in enumerate(state_shapes[-1]):
          output_name = 's{}'.format(output_idx)
          output_tensor = _ReshapeRetVal(output_name, state_shape)
          output_tensors.append(output_tensor)
        if len(output_tensors) == 1:
          output_tensors = output_tensors[0]
        else:
          output_tensors = tuple(output_tensors)
      tf.logging.info('pipeline output = {}'.format(output_tensors))
      return output_tensors
コード例 #14
0
    def Export(cls,
               model_cfg,
               model_task_name=None,
               device_options=InferenceDeviceOptions(
                   device='',
                   retain_device_placement=False,
                   var_options=None,
                   gen_init_op=True,
                   dtype_override=None),
               freeze_checkpoint=None,
               freeze_defaults=False,
               export_path=None,
               subgraph_filter=None,
               random_seed=None,
               disable_packed_input=True):
        """Exports a InferenceGraph proto with piecewise subgraphs.

        Sets FLAGS.enable_asserts to False unless user explicitly sets it to
        True. This change to the global flag is not undone on return.

        Args:
          model_cfg: a Params instance as returned by
            model_registry.GetParams(modelname, 'Test') or model_params.Model().
            It is mutated in place (random_seed, is_inference, packed_input,
            and possibly dtypes).
          model_task_name: The task to generate an inference graph for. Should
            be None for single-task models.
          device_options: Device options for the accelerator used for serving.
          freeze_checkpoint: The checkpoint to load. Loads and freezes the model
            if given.
          freeze_defaults: Default initializes the graph and freeze. Useful for
            early testing of downstream tools without having a checkpoint.
          export_path: If not None, write the inference graph in ASCII to this
            path.
          subgraph_filter: A list of subgraph names. If not None or empty,
            export only this list of inference subgraphs.
          random_seed: Fixes the random seed in the exported inference graph.
          disable_packed_input: Disable packed input for inference writing
            purposes.

        Returns:
          InferenceGraph proto.

        Raises:
          ValueError: if the model does not support the listed subgraphs.
          ValueError: if freezing is requested and
            `cls._DeviceSupportsFreezing(device_options)` returns True.
        """
        assert issubclass(model_cfg.cls, base_model.BaseModel)

        # Disable assertions unless user explicitly enables it.
        if FLAGS['enable_asserts'].using_default_value:
            FLAGS.enable_asserts = False

        # TODO(laurenzo): Work out how much we need to specify here in terms of
        # cluster configuration.
        cls._SetClusterParams(model_cfg.cluster, device_options)

        # Configure the model.
        model_cfg.random_seed = random_seed
        model_cfg.is_inference = True

        if disable_packed_input:

            # Turns off packed_input on the task's encoder/decoder params where
            # those params exist.
            def _DisablePackedInput(task):
                if (_ParamExists(task, 'encoder')
                        and _ParamExists(task.encoder, 'packed_input')):
                    task.encoder.packed_input = False
                if (_ParamExists(task, 'decoder')
                        and _ParamExists(task.decoder, 'packed_input')):
                    task.decoder.packed_input = False

            if issubclass(model_cfg.cls, base_model.MultiTaskModel):
                for _, task_param in model_cfg.task_params.IterParams():
                    _DisablePackedInput(task_param)
            else:
                _DisablePackedInput(model_cfg.task)

        tf.logging.info('Model %s params:', model_cfg.name)
        for line in model_cfg.ToText().split('\n'):
            tf.logging.info('%s', line)

        # Instantiate the graph.
        graph = tf.Graph()
        with graph.as_default():
            tf.random.set_seed(random_seed)
            cluster = model_cfg.cluster.Instantiate()
            device = cluster.GetPlacer()
            tpu_const_scope = _DummyScope()
            if (IsTpu(device_options)
                    and device_options.var_options == 'AS_CONSTANTS'):
                # Do not specify devices for variables if we are marking them as
                # constants.
                device = ''
                tpu_const_scope = ConstGuaranteeScope()

            with cluster, tf.device(device), tpu_const_scope:

                bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
                    device_options)

                if bfloat16_override:
                    py_utils.UpdateDtype(model_cfg, tf.bfloat16)
                    py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

                # Hard-code TPU-related flags prior to instantiating model.
                # The previous values are restored in the `finally` below.
                old_enable_asserts = FLAGS.enable_asserts
                old_xla_device = FLAGS.xla_device
                if IsTpu(device_options):
                    FLAGS.enable_asserts = False
                    FLAGS.xla_device = 'tpu'

                # Ensure the global_step variable is created.
                # The identity gives the exported graph a stable
                # 'global_step_tensor' op name.
                global_step_var = py_utils.GetOrCreateGlobalStepVar()
                global_step = tf.identity(global_step_var,
                                          name='global_step_tensor')

                with py_utils.GlobalStepContext(global_step):
                    try:
                        mdl = model_cfg.Instantiate()
                        # With EMA enabled, restore via the mapping from
                        # mdl.ema.variables_to_restore instead of the raw
                        # global variables.
                        variables_to_restore = (_MakeVariableDictionary(
                            tf.global_variables()) if not mdl.ema else
                                                mdl.ema.variables_to_restore(
                                                    mdl.variables_for_ema))

                        if bfloat16_override:
                            saver_var_spec = (
                                bfloat16_variables.
                                get_saver_spec_for_variables_with_bf16_overrides(
                                    variables_to_restore))
                        else:
                            saver_var_spec = variables_to_restore

                        saver = tf.train.Saver(saver_var_spec)
                        tf.variables_initializer(tf.global_variables(),
                                                 name='init_all_variables')
                        if IsTpu(
                                device_options) and device_options.gen_init_op:
                            tf.group(tf.tpu.initialize_system(),
                                     name='tpu_init_op')

                        model_task = mdl.GetTask(model_task_name)

                        inference_graph_proto = inference_graph_pb2.InferenceGraph(
                        )
                        # Inference() may return a dict or a proto; copy only
                        # the subgraphs allowed by subgraph_filter.
                        subgraphs_proto = model_task.Inference()
                        if isinstance(subgraphs_proto, dict):
                            subgraphs_proto = ConvertSubgraphDictToProto(
                                subgraphs_proto)
                        for name, subgraph in subgraphs_proto.subgraphs.items(
                        ):
                            if not subgraph_filter or name in subgraph_filter:
                                inference_graph_proto.subgraphs[name].CopyFrom(
                                    subgraph)

                        # Add a table init op to the graph (the global variable
                        # init op was added above). Tables can be declared
                        # anywhere in the graph, so this op has to be added last.
                        tf.tables_initializer(name='init_all_tables')
                    finally:
                        # Reset TPU-related flags after model instantiation.
                        FLAGS.enable_asserts = old_enable_asserts
                        FLAGS.xla_device = old_xla_device

        tf.logging.info('Graph contains ops: %r',
                        [op.name for op in graph.get_operations()])

        inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

        # Freezing.
        if freeze_defaults or freeze_checkpoint:
            output_op_names = GetOutputOpNames(graph,
                                               inference_graph_proto,
                                               preserve_colocation_nodes=False)
            # NOTE(review): this raises when _DeviceSupportsFreezing() returns
            # True, which reads inverted relative to the helper's name. Confirm
            # the helper's semantics (e.g. whether it is True for TPU).
            if cls._DeviceSupportsFreezing(device_options):
                raise ValueError(
                    'freeze_checkpoint cannot be used with device ' +
                    device_options.device)
            if freeze_checkpoint:
                tf.logging.info('Freezing graph from checkpoint: %s',
                                freeze_checkpoint)
                graph_def = _FreezeGraphFromCheckpoint(graph, saver,
                                                       freeze_checkpoint,
                                                       output_op_names)
            elif freeze_defaults:
                tf.logging.info('Default initializing graph and freezing.')
                graph_def = _FreezeDefaults(graph, output_op_names)
        else:
            output_op_names = GetOutputOpNames(graph, inference_graph_proto)

            # Prune the graph to just the parts we need.
            # To support restoring, we have to not prune out the restore node.
            output_op_names.append('init_all_tables')
            output_op_names.append('init_all_variables')
            output_op_names.append('save/control_dependency')
            output_op_names.append('save/restore_all')
            if IsTpu(device_options) and device_options.gen_init_op:
                output_op_names.append('tpu_init_op')
            graph_def = graph.as_graph_def()
            tf.logging.info('Pruning graph to output ops: %r', output_op_names)
            graph_def = tf.graph_util.extract_sub_graph(
                graph_def, output_op_names)

        if not device_options.retain_device_placement:
            # Clear the device so that the runtime can choose.
            # This covers nodes in the main graph and in library functions.
            tf.logging.info('Clearing device placement for: %s',
                            device_options.device)
            for node in graph_def.node:
                node.ClearField('device')
            for function in graph_def.library.function:
                for node_def in function.node_def:
                    node_def.ClearField('device')

        inference_graph_proto.graph_def.CopyFrom(graph_def)

        if export_path:
            with tf.io.gfile.GFile(export_path, 'w') as f:
                f.write(text_format.MessageToString(inference_graph_proto))
        return inference_graph_proto