Code example #1
    def Export(cls,
               model_cfg,
               model_task_name=None,
               device_options=InferenceDeviceOptions(
                   device='',
                   retain_device_placement=False,
                   var_options=None,
                   gen_init_op=True,
                   dtype_override=None),
               freeze_checkpoint=None,
               freeze_defaults=False,
               export_path=None,
               subgraph_filter=None,
               random_seed=None,
               disable_packed_input=True):
        """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless the user explicitly enables it.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default-initializes the graph and freezes it. Useful
        for early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A list of subgraph names. If not None or empty, export
        only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input for inference writing purposes.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
        assert issubclass(model_cfg.cls, base_model.BaseModel)

        # Disable assertions unless the user explicitly enables them.
        if FLAGS['enable_asserts'].using_default_value:
            FLAGS.enable_asserts = False

        # TODO(laurenzo): Work out how much we need to specify here in terms of
        # cluster configuration.
        cls._SetClusterParams(model_cfg.cluster, device_options)

        # Configure the model.
        model_cfg.random_seed = random_seed
        model_cfg.is_inference = True

        if disable_packed_input:

            def _DisablePackedInput(task):
                if (_ParamExists(task, 'encoder')
                        and _ParamExists(task.encoder, 'packed_input')):
                    task.encoder.packed_input = False
                if (_ParamExists(task, 'decoder')
                        and _ParamExists(task.decoder, 'packed_input')):
                    task.decoder.packed_input = False

            if issubclass(model_cfg.cls, base_model.MultiTaskModel):
                for _, task_param in model_cfg.task_params.IterParams():
                    _DisablePackedInput(task_param)
            else:
                _DisablePackedInput(model_cfg.task)

        tf.logging.info('Model %s. Params: %s', model_cfg.name,
                        model_cfg.ToText())

        # Instantiate the graph.
        graph = tf.Graph()
        with graph.as_default():
            tf.set_random_seed(random_seed)
            cluster = model_cfg.cluster.Instantiate()
            device = cluster.GetPlacer()
            tpu_const_scope = _DummyScope()
            if (IsTpu(device_options)
                    and device_options.var_options == 'AS_CONSTANTS'):
                # Do not specify devices for variables if we are marking them as
                # constants.
                device = ''
                tpu_const_scope = ConstGuaranteeScope()

            with cluster, tf.device(device), tpu_const_scope:

                bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
                    device_options)

                if bfloat16_override:
                    py_utils.UpdateDtype(model_cfg, tf.bfloat16)
                    py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

                # Hard-code TPU-related flags prior to instantiating model.
                old_enable_asserts = FLAGS.enable_asserts
                old_xla_device = FLAGS.xla_device
                if IsTpu(device_options):
                    FLAGS.enable_asserts = False
                    FLAGS.xla_device = 'tpu'

                # Ensure the global_step variable is created.
                global_step_var = py_utils.GetOrCreateGlobalStepVar()
                global_step = tf.identity(global_step_var,
                                          name='global_step_tensor')

                with py_utils.GlobalStepContext(global_step):
                    try:
                        mdl = model_cfg.Instantiate()
                        variables_to_restore = (_MakeVariableDictionary(
                            tf.global_variables()) if not mdl.ema else
                                                mdl.ema.variables_to_restore(
                                                    mdl.variables_for_ema))

                        if bfloat16_override:
                            saver_var_spec = (
                                bfloat16_variables.
                                get_saver_spec_for_variables_with_bf16_overrides(
                                    variables_to_restore))
                        else:
                            saver_var_spec = variables_to_restore

                        saver = tf.train.Saver(saver_var_spec)
                        tf.variables_initializer(tf.global_variables(),
                                                 name='init_all_variables')
                        if IsTpu(
                                device_options) and device_options.gen_init_op:
                            tf.group(tf.tpu.initialize_system(),
                                     name='tpu_init_op')

                        model_task = mdl.GetTask(model_task_name)

                        inference_graph_proto = inference_graph_pb2.InferenceGraph(
                        )
                        subgraphs_proto = model_task.Inference()
                        if isinstance(subgraphs_proto, dict):
                            subgraphs_proto = ConvertSubgraphDictToProto(
                                subgraphs_proto)
                        for name, subgraph in subgraphs_proto.subgraphs.items(
                        ):
                            if not subgraph_filter or name in subgraph_filter:
                                inference_graph_proto.subgraphs[name].CopyFrom(
                                    subgraph)

                        # Add a table init op and global variable init op to the graph.
                        # Tables can be declared anywhere in the graph, so this op has to be
                        # added last.
                        tf.tables_initializer(name='init_all_tables')
                    finally:
                        # Reset TPU-related flags after model instantiation.
                        FLAGS.enable_asserts = old_enable_asserts
                        FLAGS.xla_device = old_xla_device

        tf.logging.info('Graph contains ops: %r',
                        [op.name for op in graph.get_operations()])

        inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

        # Freezing.
        if freeze_defaults or freeze_checkpoint:
            output_op_names = GetOutputOpNames(graph,
                                               inference_graph_proto,
                                               preserve_colocation_nodes=False)
            if cls._DeviceSupportsFreezing(device_options):
                raise ValueError(
                    'freeze_checkpoint cannot be used with device ' +
                    device_options.device)
            if freeze_checkpoint:
                tf.logging.info('Freezing graph from checkpoint: %s',
                                freeze_checkpoint)
                graph_def = _FreezeGraphFromCheckpoint(graph, saver,
                                                       freeze_checkpoint,
                                                       output_op_names)
            elif freeze_defaults:
                tf.logging.info('Default initializing graph and freezing.')
                graph_def = _FreezeDefaults(graph, output_op_names)
        else:
            output_op_names = GetOutputOpNames(graph, inference_graph_proto)

            # Prune the graph to just the parts we need.
            # To support restoring, we must not prune out the restore nodes.
            output_op_names.append('init_all_tables')
            output_op_names.append('init_all_variables')
            output_op_names.append('save/control_dependency')
            output_op_names.append('save/restore_all')
            if IsTpu(device_options) and device_options.gen_init_op:
                output_op_names.append('tpu_init_op')
            graph_def = graph.as_graph_def()
            tf.logging.info('Pruning graph to output ops: %r', output_op_names)
            graph_def = tf.graph_util.extract_sub_graph(
                graph_def, output_op_names)

        if not device_options.retain_device_placement:
            # Clear the device so that the runtime can choose.
            tf.logging.info('Clearing device placement for: %s',
                            device_options.device)
            for node in graph_def.node:
                node.ClearField('device')
            for function in graph_def.library.function:
                for node_def in function.node_def:
                    node_def.ClearField('device')

        inference_graph_proto.graph_def.CopyFrom(graph_def)

        if export_path:
            with tf.gfile.Open(export_path, 'w') as f:
                f.write(text_format.MessageToString(inference_graph_proto))
        return inference_graph_proto
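
A minimal usage sketch of this method, for context: it assumes the method lives on lingvo's inference_graph_exporter.InferenceGraphExporter and that 'image.mnist.LeNet5' is a registered model name; both are assumptions to adapt to your own setup.

# Hedged usage sketch: module paths and the model name are assumptions.
from lingvo import model_registry
from lingvo.core import inference_graph_exporter

model_cfg = model_registry.GetParams('image.mnist.LeNet5', 'Test')
inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
    model_cfg,
    freeze_defaults=True,                      # default-initialize and freeze
    export_path='/tmp/inference_graph.pbtxt',  # also write the ASCII proto
    subgraph_filter=['default'])               # keep only this subgraph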
Code example #2
File: executor.py  Project: ankitshah009/lingvo
    def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
        If train_cfg is a SingleTaskModelParams, we expect only one entry.
      *args: List args to pass through to BaseRunner.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        if py_utils.IsEagerMode():
            assert tf.executing_eagerly()
            tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')

            # Connect to the TPU runtime.
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
            tf.config.experimental_connect_to_cluster(resolver)

        super().__init__(train_cfg, *args, **kwargs)

        data_parallelism = self._cluster.num_splits_per_client
        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(self._logdir, 'train')

        self._variable_renaming_rules = []

        self._ml_perf = None

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        self._WriteToLog(
            text_format.MessageToString(train_cfg.ToProto(), as_utf8=True),
            self._checkpoint_dir, 'trainer_params.pbtxt')
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        train_cfg = self.params

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit(job=None):
            """Wait until the model is ready."""
            try:
                if py_utils.IsEagerMode():
                    topology = tf.tpu.experimental.initialize_tpu_system(
                        resolver)
                else:
                    # tpu.initialize_system() is called with None as embedding_config, as
                    # embedding_config is not available yet. Later in _Loop, it is called
                    # with the correct embedding_config. Since it cannot be called twice
                    # in the same graph with different embedding_config, we use a
                    # dummy_graph here.
                    dummy_graph = tf.Graph()
                    with dummy_graph.as_default():
                        tpu_initialize_system_op = tf.tpu.initialize_system(
                            embedding_config=None, job=job)

                    with self._GetSession(graph=dummy_graph) as sess:
                        topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_computation_shape is None:
                    computation_shape = py_utils.ComputationShape(
                        num_devices_per_split, topology)
                else:
                    computation_shape = train_cfg.train.tpu_computation_shape
                    assert num_devices_per_split == np.prod(computation_shape)

                if train_cfg.train.tpu_device_order_mode is None:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism)
                else:
                    self.device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=computation_shape,
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(self.device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        if len(self._cluster.all_worker_names) > 1:
            for worker in self._cluster.all_worker_names:
                _WaitTillInit(worker)
        else:
            _WaitTillInit(None)

        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []
        self._ckpt_programs = []

        self._checkpoint_to_load = None
        with self._cluster:
            # Create the ExponentialMovingAverage singleton shared by all programs, if
            # applicable.
            ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
            for task_string, program_schedule_params in ps_params_dict.items():
                program_schedule_params.logdir = self._logdir
                program_schedule_params.num_splits_per_client = data_parallelism
                program_schedule_params.task_name = task_string
                # If the model was created above, we'll inject it here as a
                # shared_model.
                ps = program_schedule_params.Instantiate(
                    shared_model=shared_model,
                    trial=self._trial,
                    ema=ema,
                    tf_master=self._tf_master)
                self._program_schedule_dict[task_string] = ps
                tf.logging.info('program_schedule_params: %s',
                                program_schedule_params.ToText())
                self._programs += ps.Programs()
                if ps.train_program:
                    self._ckpt_programs.append(ps.train_program)
                else:
                    self._ckpt_programs += ps.Programs()
                if program_schedule_params.ml_perf.benchmark_name is not None:
                    self._ml_perf = program_schedule_params.ml_perf
                if ('checkpoint_to_load' in program_schedule_params
                        and program_schedule_params.checkpoint_to_load):
                    if (self._checkpoint_to_load
                            and (self._checkpoint_to_load !=
                                 program_schedule_params.checkpoint_to_load)):
                        raise ValueError(
                            f'Multiple values found for checkpoint_to_load: '
                            f'{self._checkpoint_to_load}, '
                            f'{program_schedule_params.checkpoint_to_load}.')
                    self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

        tf.logging.info('num_programs: %d', len(self._programs))

        # When running in a vizier trainer, the executor reports infeasible runs
        # in case of errors. The programs report metrics and normal completions.
        for program in self._programs:
            if program._should_report_metrics:
                self._should_report_metrics = True

        with self._cluster, tf.container(
                self._container_id), contextlib.ExitStack() as stack:
            if not py_utils.IsEagerMode():
                stack.enter_context(self._graph.as_default())

                if FLAGS.use_tpu_mirrored_vars:
                    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
                        FLAGS.tf_master,
                        job_name=FLAGS.worker_job[len('/job:'):])
                    self._tpu_strategy = tf.distribute.experimental.TPUStrategy(
                        resolver, device_assignment=self.device_assignment)
                    stack.enter_context(self._tpu_strategy.scope())
                    stack.enter_context(
                        tpu_strategy._TPUReplicaContext(self._tpu_strategy))
                else:
                    stack.enter_context(tf.device(self._cluster.GetPlacer()))

            if FLAGS.pdb_on_exception:
                stack.enter_context(pdb_wrapper.catch_post_mortem())
            with py_utils.VariableStore(), py_utils.VariableRenameScope(
                    self._variable_renaming_rules):
                # `BuildTpuSubgraph` has to be called before checkpoint restore, so that
                # the optimizer slot variables are guaranteed to be initialized before
                # they get loaded. Otherwise, the optimizers' slot variables will not
                # be properly loaded when V1 checkpoint is used.
                for program in self._programs:
                    program.BuildTpuSubgraph()
                    py_utils.ClearTpuSummaryTensors()

            if not py_utils.IsEagerMode():
                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
                self._initialize_global_vars = tf.global_variables_initializer(
                )

            checkpointer_models = [
                program.GetModel() for program in self._ckpt_programs
            ]

            if py_utils.IsEagerMode():
                if FLAGS.use_v2_checkpoints_in_eager:
                    self._checkpointer = checkpointer.EagerCheckpointerV2(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
                else:
                    self._checkpointer = checkpointer.EagerCheckpointerV1(
                        self._checkpoint_dir,
                        models=checkpointer_models,
                        init_op=None,
                        train_params=train_cfg.train,
                        save_only=False)
            else:
                self._checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    models=checkpointer_models,
                    init_op=self._initialize_global_vars,
                    train_params=train_cfg.train,
                    save_only=False)

            for program in self._programs:
                program.SetStatusMessageFn(self._SetStatusMessage)

            tpu_embedding_collection = (
                tpu_embedding_layers.TpuEmbeddingCollection.Get())
            self._load_ops = tpu_embedding_collection.load_ops
            self._retrieve_ops = tpu_embedding_collection.retrieve_ops
            self._tpu_embedding = tpu_embedding_collection.tpu_embedding
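
The eager-mode branch at the top of this constructor follows the standard TF2 TPU bring-up sequence (resolve the cluster, connect to it, initialize the TPU system). Below is a standalone sketch of just that sequence; the TPU address is a placeholder, whereas the code above derives it from FLAGS.tf_master.

# Standalone sketch of the eager TPU bring-up; the address is a placeholder.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='grpc://10.0.0.2:8470')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
print('TPU mesh shape:', topology.mesh_shape)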
Code example #3
    def __init__(self, task_dict, program_schedule_params, model_task_name,
                 logdir, tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      task_dict: A dict of dataset_name -> task params.
      program_schedule_params: A ProgramSchedule params.
      model_task_name: An override for multi-task models, currently unused.
      logdir:  String path to the log directory to output to.
      tf_master: String path to the master job, e.g. 'local'.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        # TODO(blee): fix this.
        train_params = task_dict['Train']
        super(ExecutorTpu, self).__init__(train_params, model_task_name,
                                          logdir, tf_master, **kwargs)

        # There is a single Executor task
        assert self._cluster.num_replicas == 1
        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        # Update run-time params
        program_schedule_params.task_dict = task_dict
        program_schedule_params.logdir = logdir
        program_schedule_params.num_splits_per_client = data_parallelism

        self._programs = []
        self._program_schedule = program_schedule_params.Instantiate()

        tf.logging.info('program_schedule_params: %s',
                        program_schedule_params.ToText())

        self._programs += self._program_schedule.Programs()

        # BaseRunner legacy
        self.enqueue_ops = None

        def ComputationShape(split_size):
            """Decides the computation shape based on the split_size."""
            computation_shape = None
            if split_size == 1:
                computation_shape = [1, 1, 1]
            elif split_size == 2:
                computation_shape = [1, 1, 2]
            elif split_size == 4:
                computation_shape = [1, 2, 2]
            elif split_size == 8:
                computation_shape = [2, 2, 2]
            elif split_size == 16:
                computation_shape = [4, 2, 2]
            else:
                assert False, (
                    'Model parallelism with %d devices is currently not'
                    ' supported.' % split_size)
            assert computation_shape is not None
            return computation_shape

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit():
            """Wait until the model is ready."""
            try:
                with self._GetSession() as sess:
                    topology = sess.run(
                        tf.tpu.initialize_system(embedding_config=None,
                                                 job=None))
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=ComputationShape(
                            num_devices_per_split),
                        num_replicas=data_parallelism)
                    py_utils.SetTpuDeviceAssignment(device_assignment)
                    tf.logging.info('device_assignment.core_assignment: %s',
                                    str(device_assignment.core_assignment))
                    tf.logging.info(
                        'device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        _WaitTillInit()

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(self._cluster.job_spec.name):
                for program in self._programs:
                    program.BuildTpuSubgraph()
                self.initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
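
The ComputationShape helper in this listing is a fixed lookup from the model-parallel split size to a 3-D TPU core grid whose volume equals that split size. An equivalent table-driven sketch of the same mapping:

import numpy as np

# Same mapping as the ComputationShape helper above; each shape is a 3-D core
# grid whose volume equals the model-parallel split size.
COMPUTATION_SHAPES = {
    1: [1, 1, 1],
    2: [1, 1, 2],
    4: [1, 2, 2],
    8: [2, 2, 2],
    16: [4, 2, 2],
}
for split_size, shape in COMPUTATION_SHAPES.items():
    assert np.prod(shape) == split_size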
Code example #4
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
          If train_cfg is a SingleTaskModelParams, we expect only one entry.
      model_task_name: An override for multi-task models, currently unused.
      logdir:  String path to the log directory to output to.
      tf_master: String path to the master job, e.g. 'local'.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        super(ExecutorTpu, self).__init__(train_cfg, model_task_name, logdir,
                                          tf_master, **kwargs)

        self._cluster_def = self._cluster.worker_cluster_def

        # There is a single Executor task
        assert self._cluster.num_replicas == 1
        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(logdir, 'train')

        self._variable_renaming_rules = []

        self._ml_perf = None

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            ps = program_schedule_params.Instantiate()
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()
            if program_schedule_params.ml_perf.benchmark_name is not None:
                self._ml_perf = program_schedule_params.ml_perf

        tf.logging.info('num_programs: %d', len(self._programs))

        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        # BaseRunner legacy
        self.enqueue_ops = None

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit():
            """Wait until the model is ready."""
            try:
                with self._graph.as_default(), self._GetSession(
                        cluster_def=self._cluster_def,
                        disable_meta_optimizer=FLAGS.
                        disable_meta_optimizer_in_executor) as sess:
                    topology = sess.run(
                        tf.tpu.initialize_system(embedding_config=None,
                                                 job=None))
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split),
                        num_replicas=data_parallelism)
                    py_utils.SetTpuDeviceAssignment(device_assignment)
                    tf.logging.info('device_assignment.core_assignment: %s',
                                    str(device_assignment.core_assignment))
                    tf.logging.info(
                        'device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        _WaitTillInit()

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(
                    self._cluster.job_spec.name if not FLAGS.
                    cluster_placer_in_executor else self._cluster.GetPlacer()):
                with py_utils.VariableRenameScope(
                        self._variable_renaming_rules):
                    _ = py_utils.GetOrCreateGlobalStepVar()
                    for program in self._programs:
                        program.BuildTpuSubgraph()
                for program in self._programs:
                    program.SetStatusMessageFn(self._SetStatusMessage)
                    program.CreateCheckpointer()
                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()

                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    train_params=train_cfg.train,
                    save_only=True)
Code example #5
    def testShampooWithMatrixShapedTensors(self):
        # A parameter matrix of size [4, 2] results in statistics L_t and R_t
        # of sizes [4, 4] and [2, 2], respectively.
        size = [4, 2]
        init_var_np = np.zeros(size)
        # Initialize gradient as random tensor.
        grad_np = np.random.rand(size[0], size[1])

        with tf.Session():
            global_step = tf.Variable(0, dtype=tf.int64)
            var = tf.Variable(init_var_np, dtype=tf.float32)
            grad = tf.constant(grad_np, dtype=tf.float32)

            opt = distributed_shampoo.DistributedShampoo(
                learning_rate=1.0,
                momentum=0.0,
                start_preconditioning_steps=0,
                synchronous_preconditioning=True,
                global_step=global_step)

            # Run a single step of gradient update.
            update = opt.apply_gradients(zip([grad], [var]),
                                         global_step=global_step)

            # Preconditioner computation and assignments to variables.
            compute_preconditioner_op = opt.invoke_async_preconditioner_computation(
                tf.cast(global_step, tf.int32))
            assign_preconditioners_to_vars_op = (
                opt.assign_preconditioner_to_host_vars())

            self.evaluate(tf.global_variables_initializer())
            tf.tables_initializer().run()

            init_val = self.evaluate(var)
            self.assertAllCloseAccordingToType(init_var_np, init_val)

            def np_power(mat_g, alpha, matrix_epsilon=1e-6):
                """Computes mat_g^alpha for a square symmetric matrix mat_g."""
                mat_for_svd = mat_g + np.eye(mat_g.shape[0]) * matrix_epsilon
                mat_u, diag_d, mat_v = np.linalg.svd(mat_for_svd,
                                                     full_matrices=True)
                diag_d = np.power(np.maximum(diag_d, matrix_epsilon), alpha)
                return np.dot(mat_u, np.dot(np.diag(diag_d), mat_v))

            def norm(val):
                return np.sqrt(np.sum(np.square(val)))

            # Run a step of preconditioner update.
            update.run()

            mat_g1 = np.dot(grad_np, grad_np.transpose())
            expected_mat_g1 = self.evaluate(
                opt.get_slot(var, 'mat_statistics_0'))
            self.assertAllCloseAccordingToType(mat_g1,
                                               expected_mat_g1,
                                               atol=1e-1)

            mat_g2 = np.dot(grad_np.transpose(), grad_np)
            expected_mat_g2 = self.evaluate(
                opt.get_slot(var, 'mat_statistics_1'))
            self.assertAllCloseAccordingToType(mat_g2,
                                               expected_mat_g2,
                                               atol=1e-1)

            compute_preconditioner_op.run()
            assign_preconditioners_to_vars_op.run()

            mat_left = np_power(mat_g1, -0.25)
            expected_mat_left = self.evaluate(
                opt.get_slot(var, 'mat_preconditioner_0'))
            self.assertAllCloseAccordingToType(mat_left,
                                               expected_mat_left,
                                               atol=1e-1)

            mat_right = np_power(mat_g2, -0.25)
            expected_mat_right = self.evaluate(
                opt.get_slot(var, 'mat_preconditioner_1'))
            self.assertAllCloseAccordingToType(mat_right,
                                               expected_mat_right,
                                               atol=1e-1)

            # As the preconditioners are initialized to all zeros, we don't
            # make any update.
            var_step_0_val = self.evaluate(var)
            self.assertAllCloseAccordingToType(init_var_np,
                                               var_step_0_val,
                                               atol=1e-1)

            # Run another step of training.
            update.run()
            var_step_1_val = self.evaluate(var)

            # New update has the scale of the second diagonal adagrad update.
            adagrad_update = grad_np / np.sqrt(2 * np.square(grad_np))
            preconditioned_grad_update = np.dot(np.dot(mat_left, grad_np),
                                                mat_right)

            # With normalization by diagonal enabled.
            var_step_1_np = init_var_np - preconditioned_grad_update * norm(
                adagrad_update) / norm(preconditioned_grad_update)
            self.assertAllCloseAccordingToType(var_step_1_np,
                                               var_step_1_val,
                                               atol=1e-1)

            # Compute new preconditioners.
            compute_preconditioner_op.run()
            assign_preconditioners_to_vars_op.run()

            # Gradients are summed over time.
            mat_g1 += np.dot(grad_np, grad_np.transpose())
            mat_left = np_power(mat_g1, -0.25)
            expected_mat_left = self.evaluate(
                opt.get_slot(var, 'mat_preconditioner_0'))
            self.assertAllCloseAccordingToType(mat_left,
                                               expected_mat_left,
                                               atol=1e-1)

            mat_g2 += np.dot(grad_np.transpose(), grad_np)
            mat_right = np_power(mat_g2, -0.25)
            expected_mat_right = self.evaluate(
                opt.get_slot(var, 'mat_preconditioner_1'))
            self.assertAllCloseAccordingToType(mat_right,
                                               expected_mat_right,
                                               atol=1e-1)
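
The preconditioner math this test checks can be restated outside TensorFlow. Below is a numpy-only sketch (the helper name is made up) of accumulating the left and right statistics G G^T and G^T G and forming the L^(-1/4) G R^(-1/4) update direction that the assertions above compare against; it mirrors the np_power helper, not the distributed_shampoo implementation itself.

import numpy as np

def inv_fourth_root(mat_g, eps=1e-6):
    # Symmetric (mat_g + eps * I)^(-1/4) via SVD, as in np_power above.
    mat = mat_g + np.eye(mat_g.shape[0]) * eps
    mat_u, diag_d, mat_v = np.linalg.svd(mat, full_matrices=True)
    diag_d = np.power(np.maximum(diag_d, eps), -0.25)
    return np.dot(mat_u, np.dot(np.diag(diag_d), mat_v))

# One Shampoo-style step for a [4, 2] gradient: accumulate statistics, then
# precondition the gradient from the left and the right.
grad = np.random.rand(4, 2)
mat_l = np.dot(grad, grad.T)  # left statistics, shape [4, 4]
mat_r = np.dot(grad.T, grad)  # right statistics, shape [2, 2]
update = np.dot(np.dot(inv_fourth_root(mat_l), grad), inv_fourth_root(mat_r))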
Code example #6
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
        If train_cfg is a SingleTaskModelParams, we expect only one entry.
      model_task_name: An override for multi-task models, currently unused.
      logdir:  String path to the log directory to output to.
      tf_master: String path to the master job, e.g. 'local'.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        super().__init__(train_cfg, model_task_name, logdir, tf_master,
                         **kwargs)

        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(logdir, 'train')

        self._variable_renaming_rules = []

        self._ml_perf = None

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'trainer_params.txt')
        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        # BaseRunner legacy
        self.enqueue_ops = None

        train_cfg = self.params

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit(job=None):
            """Wait until the model is ready."""
            try:
                # tpu.initialize_system() is called with None as embedding_config, as
                # embedding_config is not available yet. Later in _Loop, it is called
                # with the correct embedding_config. Since it cannot be called twice in
                # the same graph with different embedding_config, we use a dummy_graph
                # here.
                dummy_graph = tf.Graph()
                with dummy_graph.as_default():
                    tpu_initialize_system_op = tf.tpu.initialize_system(
                        embedding_config=None, job=job)

                with self._GetSession(graph=dummy_graph) as sess:
                    topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_device_order_mode is None:
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split, topology),
                        num_replicas=data_parallelism)
                else:
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split, topology),
                        num_replicas=data_parallelism,
                        device_order_mode=train_cfg.train.tpu_device_order_mode
                    )
                py_utils.SetTpuDeviceAssignment(device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='init_start', value=None)
        if len(self._cluster.all_worker_names) > 1:
            for worker in self._cluster.all_worker_names:
                _WaitTillInit(worker)
        else:
            _WaitTillInit(None)

        shared_model = self._MaybeConstructSharedModel(train_cfg)

        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            # If the model was created above, we'll inject it here as a shared_model.
            ps = program_schedule_params.Instantiate(shared_model=shared_model,
                                                     tf_master=self._tf_master)
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()
            if program_schedule_params.ml_perf.benchmark_name is not None:
                self._ml_perf = program_schedule_params.ml_perf

        tf.logging.info('num_programs: %d', len(self._programs))

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(self._cluster.GetPlacer()):
                with py_utils.VariableRenameScope(
                        self._variable_renaming_rules):
                    _ = py_utils.GetOrCreateGlobalStepVar()
                    for program in self._programs:
                        program.BuildTpuSubgraph()
                        py_utils.ClearTpuSummaryTensors()

                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()
                self._initialize_global_vars = tf.global_variables_initializer(
                )

                for program in self._programs:
                    program.SetStatusMessageFn(self._SetStatusMessage)
                    program.CreateCheckpointer(
                        init_op=self._initialize_global_vars)

                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    init_op=self._initialize_global_vars,
                    train_params=train_cfg.train,
                    save_only=True)

            self._load_ops = tf.get_collection(py_utils.TPU_EMBEDDING_LOAD_OPS)
            self._retrieve_ops = tf.get_collection(
                py_utils.TPU_EMBEDDING_RETRIEVE_OPS)
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            self._tpu_embedding = (tpu_embedding_collection[0]
                                   if tpu_embedding_collection else None)
            tf.io.write_graph(self._graph.as_graph_def(), self._checkpoint_dir,
                              'train.pbtxt')
Code example #7
File: executor.py  Project: jack-morrison/lingvo
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
          If train_cfg is a SingleTaskModelParams, we expect only one entry.
      model_task_name: An override for multi-task models, currently unused.
      logdir:  String path to the log directory to output to.
      tf_master: String path to the master job, e.g. 'local'.
      **kwargs: keyword args to pass through to BaseRunner.
    """
        super(ExecutorTpu, self).__init__(train_cfg, model_task_name, logdir,
                                          tf_master, **kwargs)
        self._cluster_def = self._cluster.worker_cluster_def

        # There is a single Executor task
        assert self._cluster.num_replicas == 1
        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None

        self._variable_renaming_rules = []

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            # dict views are not indexable in Python 3; materialize first.
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            ps = program_schedule_params.Instantiate()
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()

        tf.logging.info('num_programs: %d', len(self._programs))

        # BaseRunner legacy
        self.enqueue_ops = None

        def ComputationShape(split_size):
            """Decides the computation shape based on the split_size."""
            computation_shape = None
            if split_size == 1:
                computation_shape = [1, 1, 1]
            elif split_size == 2:
                computation_shape = [1, 1, 2]
            elif split_size == 4:
                computation_shape = [1, 2, 2]
            elif split_size == 8:
                computation_shape = [2, 2, 2]
            elif split_size == 16:
                computation_shape = [4, 2, 2]
            else:
                assert False, (
                    'Model parallelism with %d devices is currently not'
                    ' supported.' % split_size)
            assert computation_shape is not None
            return computation_shape

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit():
            """Wait until the model is ready."""
            try:
                with self._GetSession(cluster_def=self._cluster_def) as sess:
                    topology = sess.run(
                        tf.tpu.initialize_system(embedding_config=None,
                                                 job=None))
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=ComputationShape(
                            num_devices_per_split),
                        num_replicas=data_parallelism)
                    py_utils.SetTpuDeviceAssignment(device_assignment)
                    tf.logging.info('device_assignment.core_assignment: %s',
                                    str(device_assignment.core_assignment))
                    tf.logging.info(
                        'device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        _WaitTillInit()

        with self._graph.as_default(), tf.container(self._container_id):
            with self._cluster, tf.device(self._cluster.job_spec.name):
                with py_utils.VariableRenameScope(
                        self._variable_renaming_rules):
                    for program in self._programs:
                        program.BuildTpuSubgraph()
                    self.initialize_tables = tf.tables_initializer()
                    self._initialize_local_vars = tf.local_variables_initializer(
                    )

                self._checkpoint_dir = os.path.join(logdir, 'train')
                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    train_params=train_cfg.train,
                    save_only=True)
Code example #8
    def Export(cls,
               model_cfg,
               model_task_name=None,
               device_options=InferenceDeviceOptions(
                   device='',
                   retain_device_placement=False,
                   var_options=None,
                   gen_init_op=True,
                   dtype_override=None,
                   fprop_dtype_override=None),
               freeze_checkpoint=None,
               freeze_defaults=False,
               export_path=None,
               subgraph_filter=None,
               random_seed=None,
               disable_packed_input=True):
        """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless the user explicitly enables it.

    Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
    and multi-core inference on TPUs work properly.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default-initializes the graph and freezes it. Useful
        for early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A string or a list of subgraph names. If not None or
        empty, export only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input for inference writing purposes.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
        assert issubclass(model_cfg.cls, base_model.BaseModel)
        if device_options.dtype_override and device_options.fprop_dtype_override:
            raise ValueError(
                'device_options.dtype_override and fprop_dtype_override '
                'cannot both be set.')
        if subgraph_filter and not isinstance(subgraph_filter, (tuple, list)):
            subgraph_filter = [subgraph_filter]

        # Disable assertions unless the user explicitly enables them.
        if FLAGS['enable_asserts'].using_default_value:
            FLAGS.enable_asserts = False

        # TODO(laurenzo): Work out how much we need to specify here in terms of
        # cluster configuration.
        cls._SetClusterParams(model_cfg.cluster, device_options)

        # Configure the model.
        model_cfg.random_seed = random_seed
        model_cfg.is_inference = True

        if disable_packed_input:

            def _DisablePackedInput(task):
                if (_ParamExists(task, 'encoder')
                        and _ParamExists(task.encoder, 'packed_input')):
                    task.encoder.packed_input = False
                if (_ParamExists(task, 'decoder')
                        and _ParamExists(task.decoder, 'packed_input')):
                    task.decoder.packed_input = False

            if issubclass(model_cfg.cls, base_model.MultiTaskModel):
                for _, task_param in model_cfg.task_params.IterParams():
                    _DisablePackedInput(task_param)
            else:
                _DisablePackedInput(model_cfg.task)

        tf.logging.info('Model %s params:', model_cfg.name)
        for line in model_cfg.ToText().split('\n'):
            tf.logging.info('%s', line)

        # Instantiate the graph.
        graph = tf.Graph()
        with graph.as_default():
            tf.random.set_seed(random_seed)
            cluster = model_cfg.cluster.Instantiate()
            device = cluster.GetPlacer()
            tpu_const_scope = _DummyScope()
            if (IsTpu(device_options)
                    and device_options.var_options == 'AS_CONSTANTS'):
                # Do not specify devices for variables if we are marking them as
                # constants.
                device = ''
                tpu_const_scope = ConstGuaranteeScope()
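                # Assumption (not shown in this snippet): ConstGuaranteeScope is
                # expected to wrap variable reads in tf.guarantee_const so the
                # TPU compiler can treat them as compile-time constants.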

            with cluster, tf.device(device), tpu_const_scope:

                bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
                    device_options)

                if bfloat16_override:
                    py_utils.UpdateDtype(model_cfg, tf.bfloat16)
                    py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

                act_bfloat16_override = ShouldForceBfloat16ForActivations(
                    device_options)
                if act_bfloat16_override:
                    py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

                # Hard-code TPU-related flags prior to instantiating model.
                old_enable_asserts = FLAGS.enable_asserts
                old_xla_device = FLAGS.xla_device
                if IsTpu(device_options):
                    FLAGS.enable_asserts = False
                    FLAGS.xla_device = 'tpu'

                try:
                    mdl = model_cfg.Instantiate()
                    task = mdl.GetTask(model_task_name)

                    # When EMA is enabled, restore the moving-average (shadow)
                    # variables instead of the raw training variables.
                    if mdl.ema:
                        variables_to_restore = mdl.ema.variables_to_restore(
                            mdl.variables_for_ema)
                    else:
                        variables_to_restore = _MakeVariableDictionary(
                            tf.global_variables())

                    if bfloat16_override:
                        saver_var_spec = (
                            bfloat16_variables.
                            get_saver_spec_for_variables_with_bf16_overrides(
                                variables_to_restore))
                    else:
                        saver_var_spec = variables_to_restore

                    saver = tf.train.Saver(saver_var_spec)
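                    # Named init ops: a serving binary can run these by name; the
                    # names are explicitly kept when the graph is pruned below.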
                    tf.variables_initializer(tf.global_variables(),
                                             name='init_all_variables')
                    if IsTpu(device_options) and device_options.gen_init_op:
                        tf.group(tf.tpu.initialize_system(),
                                 name='tpu_init_op')

                    if freeze_checkpoint or freeze_defaults:
                        # Replace variables in theta with tf.identity tensors before
                        # freezing, so the frozen graph does not reference DT_RESOURCE types.
                        def AddIdentityToTheta(layer):
                            layer._private_theta = layer._private_theta.Transform(
                                tf.identity)  # pylint: disable=protected-access
                            layer.children.Transform(AddIdentityToTheta)

                        AddIdentityToTheta(task)

                    inference_graph_proto = inference_graph_pb2.InferenceGraph()
                    subgraphs_proto = task.Inference()
                    if isinstance(subgraphs_proto, dict):
                        subgraphs_proto = ConvertSubgraphDictToProto(
                            subgraphs_proto)
                    for name, subgraph in subgraphs_proto.subgraphs.items():
                        if not subgraph_filter or name in subgraph_filter:
                            inference_graph_proto.subgraphs[name].CopyFrom(
                                subgraph)

                    # Yes, graph collections are bad, but this seems to be the
                    # easiest way to get these assets registered from
                    # TextFileInitializer.
                    assets_collection = tf.compat.v1.get_collection(
                        tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
                    for asset in assets_collection:
                        if asset.op.type == 'Const' and asset.op.get_attr(
                                'dtype') == tf.dtypes.string:
                            constant_value = asset.op.get_attr('value')
                            if constant_value.string_val:
                                tf.logging.info('Found asset file_path: %s',
                                                constant_value.string_val[0])
                                asset_file_def = inference_graph_proto.asset_file_def.add()
                                asset_file_def.tensor_info.name = asset.name
                                asset_file_def.filename = constant_value.string_val[0]

                    # Add a table init op and global variable init op to the graph.
                    # Tables can be declared anywhere in the graph, so this op has to be
                    # added last.
                    tf.tables_initializer(name='init_all_tables')
                finally:
                    # Reset TPU-related flags after model instantiation.
                    FLAGS.enable_asserts = old_enable_asserts
                    FLAGS.xla_device = old_xla_device

        tf.logging.info('Graph contains ops: %r',
                        [op.name for op in graph.get_operations()])

        # Embed the SaverDef so consumers of the proto can restore a checkpoint
        # into the exported (unfrozen) graph.
        inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

        # Freezing.
        if freeze_defaults or freeze_checkpoint:
            output_op_names = GetOutputOpNames(graph,
                                               inference_graph_proto,
                                               preserve_colocation_nodes=False)
            if cls._DeviceSupportsFreezing(device_options):
                raise ValueError(
                    'freeze_checkpoint cannot be used with device ' +
                    device_options.device)
            if freeze_checkpoint:
                tf.logging.info('Freezing graph from checkpoint: %s',
                                freeze_checkpoint)
                graph_def = _FreezeGraphFromCheckpoint(graph, saver,
                                                       freeze_checkpoint,
                                                       output_op_names)
            elif freeze_defaults:
                tf.logging.info('Default initializing graph and freezing.')
                graph_def = _FreezeDefaults(graph, output_op_names)
        else:
            output_op_names = GetOutputOpNames(graph, inference_graph_proto)

            # Prune the graph to just the parts we need.
            # To support restoring, we have to not prune out the restore node.
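            # 'save/control_dependency' and 'save/restore_all' are the Saver's
            # save and restore entry points under its default 'save' name scope.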
            output_op_names.append('init_all_tables')
            output_op_names.append('init_all_variables')
            output_op_names.append('save/control_dependency')
            output_op_names.append('save/restore_all')
            if IsTpu(device_options) and device_options.gen_init_op:
                output_op_names.append('tpu_init_op')
            graph_def = graph.as_graph_def()
            tf.logging.info('Pruning graph to output ops: %r', output_op_names)
            graph_def = tf.graph_util.extract_sub_graph(
                graph_def, output_op_names)

        if not device_options.retain_device_placement:
            # Clear the device so that the runtime can choose.
            tf.logging.info('Clearing device placement for: %s',
                            device_options.device)
            for node in graph_def.node:
                node.ClearField('device')
            for function in graph_def.library.function:
                for node_def in function.node_def:
                    node_def.ClearField('device')

        inference_graph_proto.graph_def.CopyFrom(graph_def)

        if export_path:
            with tf.io.gfile.GFile(export_path, 'w') as f:
                f.write(text_format.MessageToString(inference_graph_proto))
        return inference_graph_proto
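
A minimal usage sketch for orientation (not part of the example above): it assumes this Export method belongs to lingvo's InferenceGraphExporter class, that lingvo is installed, and that the model name 'image.mnist.LeNet5', the subgraph name 'default', the params module path, and the output path are all illustrative placeholders.

from lingvo import model_registry
from lingvo.core import inference_graph_exporter
# Importing the params module registers the model with the registry; the module
# path is an assumption for illustration only.
from lingvo.tasks.image.params import mnist  # noqa: F401

model_cfg = model_registry.GetParams('image.mnist.LeNet5', 'Test')
inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
    model_cfg,
    subgraph_filter=['default'],         # hypothetical subgraph name
    export_path='/tmp/inference.pbtxt')  # also writes the proto in ASCII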
Code example #9
    def __init__(self, decoder_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._job_name = 'decoder_' + decoder_type
        self.params.cluster.do_eval = True
        self._cluster = cluster_factory.Cluster(self.params.cluster)
        self._decoder_dir = GetDecoderDir(self._logdir, self._job_name,
                                          self._model_task_name)
        tf.io.gfile.makedirs(self._decoder_dir)

        self._decode_path = None
        # Multitask params don't have a 'task' field.
        if 'task' in self.params:
            self._decode_path = checkpointer.GetSpecificCheckpoint(
                self.params.task.eval.load_checkpoint_from)

        self._should_report_metrics = self._job_name.startswith(
            self._cluster.reporting_job)

        with self._graph.as_default(), tf.container(self._container_id):
            self._summary_writer = self._CreateSummaryWriter(self._decoder_dir)
            self._CreateTF2SummaryWriter(self._decoder_dir)
            with self._cluster, tf.device(
                    self._cluster.GetPlacer()), self._TF2SummaryContext():
                self._model = self.params.Instantiate()
                self._params = self._model.params
                self._task = self._model.GetTask(self._model_task_name)
                # Note: different graphs are constructed for different model
                # tasks, which may result in different node names being chosen.
                # Variable names, however, have to stay the same between train
                # and decode.
                cluster = self._cluster
                with tf.device(cluster.input_device):
                    input_batch = self._task.input_generator.GetPreprocessedInputBatch()

                self._dec_output = self._task.Decode(input_batch)

                for key in self._task.input_generator.GetCpuPassthroughKeys():
                    if key in input_batch:
                        if key in self._dec_output:
                            tf.logging.warning(
                                f'Key {key} already present in decode output. '
                                f'Not adding from input batch.')
                        else:
                            self._dec_output[key] = input_batch[key]

                self._summary_op = tf.summary.merge_all()
                self.checkpointer = self._CreateCheckpointer(
                    self._train_dir, self._model)
            self._CreateTF2SummaryOps()
            self._initialize_tables = tf.tables_initializer()
            self._initialize_local_vars = tf.local_variables_initializer()
            # No queues are allowed for decoder models.
            self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
            assert not self.enqueue_ops

        # Save the model params and the graph def.
        self._WriteToLog(self.params.ToText(), self._decoder_dir, 'params.txt')
        if self.params.cluster.task == 0:
            tf.io.write_graph(self._graph.as_graph_def(), self._decoder_dir,
                              '%s.pbtxt' % self._job_name)
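
A rough construction sketch, hedged: this __init__ is assumed to belong to lingvo's trainer.Decoder, the base-runner argument order shown (params, model_task_name, logdir, tf_master) is an assumption, and the model name and paths are placeholders.

from lingvo import model_registry
from lingvo import trainer
from lingvo.tasks.image.params import mnist  # noqa: F401  (registers the model)

train_cfg = model_registry.GetParams('image.mnist.LeNet5', 'Test')
# 'dev' yields the job name 'decoder_dev'; '' selects the only task and an
# in-process TF master; the logdir is a placeholder.
decoder = trainer.Decoder('dev', train_cfg, '', '/tmp/lenet5_logdir', '')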