Exemplo n.º 1
0
 def _Apply():
     """Use the matched optimizer to apply the gradients.

     For every optimizer key in `self._optimizer_map` whose entry in
     `var_grad_map` is non-empty, applies those gradients with the matching
     optimizer from `tf_optimizer_map`. It then adds that optimizer's summaries
     for the subset of `var_grad` whose variable names it owns.

     Relies on the enclosing scope for `self`, `var_grad`, `var_grad_map` and
     `tf_optimizer_map`.

     Returns:
       A grouped op named 'composite_optimizer_train_op' running all updates.
     """
     train_ops = []
     non_default_regex = [
         regex for regex in self._optimizer_map
         if regex != 'default_optimizer'
     ]
     for regex in self._optimizer_map:
         if not var_grad_map[regex]:
             continue
         opt = tf_optimizer_map[regex]
         train_ops.append(opt.apply_gradients(var_grad_map[regex]))
         # FilterKeyVal is evaluated immediately, so capturing the loop
         # variable in the lambdas below is safe.
         # pylint: disable=cell-var-from-loop, g-long-lambda
         if regex == 'default_optimizer':
             # The default optimizer's variables are those matched by none of
             # the explicit regexes (presumably how var_grad_map was built --
             # its construction is outside this view). Keep exactly that
             # complement; keeping the matches would summarize the other
             # optimizers' variables instead.
             filtered_var_grad = var_grad.FilterKeyVal(
                 lambda k, v: not any(
                     re.match(i, v.var.name) for i in non_default_regex))
         else:
             filtered_var_grad = var_grad.FilterKeyVal(
                 lambda k, v: re.match(regex, v.var.name))
         # pylint: enable=cell-var-from-loop, g-long-lambda
         self._optimizer_map[regex].AddSummary(
             self._lr_map[regex], opt, filtered_var_grad)
     return tf.group(*train_ops, name='composite_optimizer_train_op')
Exemplo n.º 2
0
 def PostTrainingStepUpdate(self, global_step):
     """Groups the parent's post-step update with per-tensor recording ops.

     For each name in `self._t_names`, the ops returned by `_RecordTensor`
     are collected, and `_SummarizeTensor` is called for its side effects.
     """
     parent_op = super(PassiveAsymQDomain, self).PostTrainingStepUpdate(
         global_step)
     all_ops = [parent_op]
     for name in self._t_names:
         record_ops = self._RecordTensor(name)
         all_ops.extend(record_ops)
         self._SummarizeTensor(name)
     return tf.group(all_ops)
    def CreateTpuEmbeddingEnqueueOps(self):
        """Creates the TpuEmbedding enqueue ops on the host.

        Note that this must be called after the instantiation of the
        monolithic TPUEmbeddingLayer.

        The embedding object is taken from the first entry of the
        `py_utils.TPU_EMBEDDING` collection. If the collection is empty, the
        (empty) key list is logged and no ops are created. Otherwise, for each
        infeed host:
        - every embedding feature of `self._batch` is split across the host's
          cores;
        - each split is converted to sparse form, treating -1 as padding;
        - the resulting enqueue ops are collected.
        A single grouped op is then appended to `self._tpu_infeed_op`.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        if (tpu_embedding is not None and num_tpu_hosts > 1
                and not p.use_per_host_infeed):
            # NOTE(review): depending on the logging backend, `fatal` may only
            # log rather than abort; confirm whether this should raise instead.
            tf.logging.fatal(
                'TPU Embedding must be used with per_host_infeed with multiple '
                'TPU host topologies.')
        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        if tpu_embedding is None:
            return

        num_cores_per_host = tpu_embedding.num_cores_per_host
        enqueue_ops = []
        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                enqueue_dict_per_core = [{} for _ in range(num_cores_per_host)]
                for key in tpu_emb_input_keys:
                    feat = self._batch[key]
                    tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
                    for core, split in enumerate(tpu_emb_feat_splitted):
                        # Dense to sparse. Note the assumption of a padding id
                        # of -1: only non-padding positions are enqueued.
                        sample_indices = tf.where(tf.not_equal(split, -1))
                        embedding_indices = tf.gather_nd(split, sample_indices)
                        enqueue_data = tpu_embedding_lib.EnqueueData(
                            embedding_indices, sample_indices)
                        enqueue_dict_per_core[core][key] = enqueue_data
                enqueue_ops += tpu_embedding.generate_enqueue_ops(
                    enqueue_dict_per_core)
        self._tpu_infeed_op.append(tf.group(*enqueue_ops))
Exemplo n.º 4
0
 def _ApplyAndReset():
     """Applies the scaled gradients, then zeroes the gradient tensors.

     `var_grad` is scaled by `1 / p.accum_steps` and passed to the wrapped
     optimizer. Every gradient entry of `var_grad` is then reset to zeros
     under a control dependency on that update. The entries are presumably
     accumulator variables; confirm with the enclosing scope, which supplies
     `self`, `lr`, `p` and `var_grad`.
     """
     scaled_var_grad = py_utils.ApplyGradMultiplier(var_grad,
                                                    1. / p.accum_steps)
     apply_op = self._opt.Apply(lr, scaled_var_grad)
     with tf.control_dependencies([apply_op]):
         reset_ops = []
         for _, accumulator in var_grad.Flatten():
             reset_ops.append(
                 tf.assign(accumulator, tf.zeros_like(accumulator)))
         return tf.group(*reset_ops)
Exemplo n.º 5
0
    def PostTrainingStepUpdate(self, global_step):
        """Returns a TF op which will be invoked at each training step.

        Subclasses of `BaseLayer` can implement this method. The method should
        return a TF op to be invoked during training after gradients are
        applied. This default implementation delegates to every child in
        `self._private_children` and groups their ops.

        Args:
          global_step: the global step.

        Returns:
          A `tf.group` of the children's post-training-step ops.
        """
        children = self._private_children.Flatten()
        return tf.group(
            *(child.PostTrainingStepUpdate(global_step) for child in children))
Exemplo n.º 6
0
    def ApplyPostTrainingLoop(self, global_step):
        """Runs each optimizer's post-TPU-training-loop computation.

        Args:
          global_step: Global step variable.

        Returns:
          Ops to run after training loop ends, grouped into a single op.
        """
        post_training_ops = []
        for opt in self._optimizer_map.values():
            post_training_ops.append(opt.ApplyPostTrainingLoop(global_step))
        return tf.group(*post_training_ops)
Exemplo n.º 7
0
 def PostTrainingStepUpdate(self, global_step):
   """Updates moving_mean, moving_variance after each training step.

   Steps:
   - Compute the batch mean and variance from the sufficient statistics in
     `self.accumulators`.
   - Move `moving_mean` / `moving_variance` towards them:
     new = old - (old - batch) * (1 - p.decay). Entries whose count is not
     greater than 0.5 are left unchanged.
   - Reset the accumulators.

   Args:
     global_step: Unused here; accepted for interface compatibility.

   Returns:
     A `tf.group` op running both moving-statistic updates and the numeric
     checks on the updated values.
   """
   p = self.params
   # Get sufficient stats that accumulates over microbatches.
   counts = self.accumulators.counts.GetValue()
   mean_ss = self.accumulators.mean_ss.GetValue()
   variance_ss = self.accumulators.variance_ss.GetValue()
   # Compute batch mean and batch variance from sufficient stats
   mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
   decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
   # Update moving_mean, moving_variance from batch mean and batch variance.
   with tf.name_scope(p.name) as scope:
     with tf.ops.colocate_with(self.vars.moving_mean):
       mean_update = tf.assign_sub(
           self.vars.moving_mean,
           tf.where(
               tf.greater(counts, 0.5),
               (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
               tf.zeros_like(self.vars.moving_mean)),
           name='moving_mean_update')
     with tf.ops.colocate_with(self.vars.moving_variance):
       var_update = tf.assign_sub(
           self.vars.moving_variance,
           tf.where(
               tf.greater(counts, 0.5),
               (self.vars.moving_variance - tf.cast(variance, p.dtype)) *
               decay, tf.zeros_like(self.vars.moving_variance)),
           name='moving_variance_update')
     # Previously the CheckNumerics results were discarded. In graph mode an
     # op nothing depends on is never executed, so the checks did not run.
     # Check the post-update values (tf.identity reads the variable under the
     # control dependency) and return the checks as part of the group.
     # Assumes CheckNumerics returns the (checked) tensor it was given.
     with tf.control_dependencies([mean_update, var_update]):
       numeric_checks = [
           py_utils.CheckNumerics(
               tf.identity(self.vars.moving_mean),
               'moving mean of {} failed numeric check'.format(scope)),
           py_utils.CheckNumerics(
               tf.identity(self.vars.moving_variance),
               'moving variance of {} failed numeric check'.format(scope)),
       ]
   self.accumulators.counts.Reset()
   self.accumulators.mean_ss.Reset()
   self.accumulators.variance_ss.Reset()
   return tf.group(mean_update, var_update, *numeric_checks)
Exemplo n.º 8
0
    def Apply(self, lr, var_grad):
        """Accumulates gradients and applies them every `p.accum_steps` steps.

        Each variable gets a non-trainable 'grad_accumulator' variable, created
        under the variable's op-name scope, and its gradient is added into it.

        When `global_step % p.accum_steps == p.accum_steps - 1`, the wrapped
        optimizer is applied to the accumulated gradients scaled by
        `1 / p.accum_steps`, and the accumulators are then zeroed. Otherwise
        the false branch is an empty group.

        Args:
          lr: learning rate passed through to the wrapped optimizer.
          var_grad: a NestedMap of (var, grad) pairs.

        Returns:
          The op produced by `tf.cond` over the two branches.
        """
        p = self.params
        accum_steps = p.accum_steps

        def _Acc(vg):
            """Adds the gradient of `vg` into its accumulator variable."""
            var, grad = vg
            with tf.variable_scope(var.op.name):
                accum_params = py_utils.WeightParams(
                    var.get_shape(), py_utils.WeightInit.Constant(0.0),
                    p.dtype)
                _, accum = py_utils.CreateVariable(
                    'grad_accumulator', accum_params, trainable=False)
                accum = tf.assign_add(accum, grad)
            return py_utils.VarGrad(var, accum)

        var_grad = var_grad.Transform(_Acc)

        def _ApplyAndReset():
            """Applies the averaged gradients, then zeroes the accumulators."""
            averaged = py_utils.ApplyGradMultiplier(var_grad, 1. / accum_steps)
            apply_op = self._opt.Apply(lr, averaged)
            with tf.control_dependencies([apply_op]):
                zero_ops = [
                    tf.assign(accum, tf.zeros_like(accum))
                    for _, accum in var_grad.Flatten()
                ]
                return tf.group(*zero_ops)

        def _Skip():
            return tf.group(tf.no_op())

        step_in_cycle = tf.math.floormod(self.theta.global_step, accum_steps)
        is_last_step = tf.equal(step_in_cycle, accum_steps - 1)
        return tf.cond(is_last_step, _ApplyAndReset, _Skip)
Exemplo n.º 9
0
    def Export(cls,
               model_cfg,
               model_task_name=None,
               device_options=InferenceDeviceOptions(
                   device='',
                   retain_device_placement=False,
                   var_options=None,
                   gen_init_op=True,
                   dtype_override=None),
               freeze_checkpoint=None,
               freeze_defaults=False,
               export_path=None,
               subgraph_filter=None,
               random_seed=None,
               disable_packed_input=True):
        """Exports an InferenceGraph proto with piecewise subgraphs.

        Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.
        When exporting for TPU, FLAGS.enable_asserts and FLAGS.xla_device are
        also overridden while the model is built. They are restored afterwards,
        even if building fails.

        Args:
          model_cfg: a Params instance as returned by
            model_registry.GetParams(modelname, 'Test') or model_params.Model().
          model_task_name: The task to generate an inference graph for. Should be
            None for single-task models.
          device_options: Device options for the accelerator used for serving.
          freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
            given.
          freeze_defaults: Default initializes the graph and freeze. Useful for
            early testing of downstream tools without having a checkpoint.
          export_path: If not None, write the inference graph in ASCII to this path.
          subgraph_filter: A list of subgraph names. If not None or empty, export
            only this list of inference subgraphs.
          random_seed: Fixes the random seed in the exported inference graph.
          disable_packed_input: Disable packed input for inference writing purposes.

        Returns:
          InferenceGraph proto.

        Raises:
          ValueError: if `subgraph_filter` is non-empty but none of the model's
            inference subgraphs are listed in it.
          ValueError: if freezing is requested and
            `cls._DeviceSupportsFreezing(device_options)` returns a truthy value.
        """
        assert issubclass(model_cfg.cls, base_model.BaseModel)

        # Disable assertions unless user explicitly enables it.
        if FLAGS['enable_asserts'].using_default_value:
            FLAGS.enable_asserts = False

        # TODO(laurenzo): Work out how much we need to specify here in terms of
        # cluster configuration.
        cls._SetClusterParams(model_cfg.cluster, device_options)

        # Configure the model.
        model_cfg.random_seed = random_seed
        model_cfg.is_inference = True

        if disable_packed_input:

            def _DisablePackedInput(task):
                """Turns off packed_input on the task's encoder/decoder if present."""
                if (_ParamExists(task, 'encoder')
                        and _ParamExists(task.encoder, 'packed_input')):
                    task.encoder.packed_input = False
                if (_ParamExists(task, 'decoder')
                        and _ParamExists(task.decoder, 'packed_input')):
                    task.decoder.packed_input = False

            if issubclass(model_cfg.cls, base_model.MultiTaskModel):
                for _, task_param in model_cfg.task_params.IterParams():
                    _DisablePackedInput(task_param)
            else:
                _DisablePackedInput(model_cfg.task)

        tf.logging.info('Model %s params:', model_cfg.name)
        for line in model_cfg.ToText().split('\n'):
            tf.logging.info('%s', line)

        # Instantiate the graph.
        graph = tf.Graph()
        with graph.as_default():
            tf.random.set_seed(random_seed)
            cluster = model_cfg.cluster.Instantiate()
            device = cluster.GetPlacer()
            tpu_const_scope = _DummyScope()
            if (IsTpu(device_options)
                    and device_options.var_options == 'AS_CONSTANTS'):
                # Do not specify devices for variables if we are marking them as
                # constants.
                device = ''
                tpu_const_scope = ConstGuaranteeScope()

            with cluster, tf.device(device), tpu_const_scope:

                bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
                    device_options)

                if bfloat16_override:
                    py_utils.UpdateDtype(model_cfg, tf.bfloat16)
                    py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

                # Hard-code TPU-related flags prior to instantiating model. The
                # try starts right after saving the old values so the flags are
                # restored even if global-step creation fails.
                old_enable_asserts = FLAGS.enable_asserts
                old_xla_device = FLAGS.xla_device
                try:
                    if IsTpu(device_options):
                        FLAGS.enable_asserts = False
                        FLAGS.xla_device = 'tpu'

                    # Ensure the global_step variable is created.
                    global_step_var = py_utils.GetOrCreateGlobalStepVar()
                    global_step = tf.identity(global_step_var,
                                              name='global_step_tensor')

                    with py_utils.GlobalStepContext(global_step):
                        mdl = model_cfg.Instantiate()
                        if mdl.ema:
                            variables_to_restore = mdl.ema.variables_to_restore(
                                mdl.variables_for_ema)
                        else:
                            variables_to_restore = _MakeVariableDictionary(
                                tf.global_variables())

                        if bfloat16_override:
                            saver_var_spec = (
                                bfloat16_variables.
                                get_saver_spec_for_variables_with_bf16_overrides(
                                    variables_to_restore))
                        else:
                            saver_var_spec = variables_to_restore

                        saver = tf.train.Saver(saver_var_spec)
                        tf.variables_initializer(tf.global_variables(),
                                                 name='init_all_variables')
                        if IsTpu(device_options) and device_options.gen_init_op:
                            tf.group(tf.tpu.initialize_system(),
                                     name='tpu_init_op')

                        model_task = mdl.GetTask(model_task_name)

                        inference_graph_proto = (
                            inference_graph_pb2.InferenceGraph())
                        subgraphs_proto = model_task.Inference()
                        if isinstance(subgraphs_proto, dict):
                            subgraphs_proto = ConvertSubgraphDictToProto(
                                subgraphs_proto)
                        for name, subgraph in (
                                subgraphs_proto.subgraphs.items()):
                            if not subgraph_filter or name in subgraph_filter:
                                inference_graph_proto.subgraphs[name].CopyFrom(
                                    subgraph)
                        # A filter that selects nothing would silently export
                        # an empty inference graph; report it instead.
                        if subgraph_filter and not inference_graph_proto.subgraphs:
                            raise ValueError(
                                'Subgraph filter {!r} matched none of the '
                                'defined subgraphs: {!r}'.format(
                                    subgraph_filter,
                                    sorted(subgraphs_proto.subgraphs)))

                        # Add a table init op and global variable init op to the graph.
                        # Tables can be declared anywhere in the graph, so this op has to be
                        # added last.
                        tf.tables_initializer(name='init_all_tables')
                finally:
                    # Reset TPU-related flags after model instantiation.
                    FLAGS.enable_asserts = old_enable_asserts
                    FLAGS.xla_device = old_xla_device

        tf.logging.info('Graph contains ops: %r',
                        [op.name for op in graph.get_operations()])

        inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

        # Freezing.
        if freeze_defaults or freeze_checkpoint:
            output_op_names = GetOutputOpNames(graph,
                                               inference_graph_proto,
                                               preserve_colocation_nodes=False)
            # NOTE(review): this raises when the helper returns a truthy value,
            # which reads inverted relative to the helper's name. The helper is
            # defined outside this view; confirm its semantics before changing.
            if cls._DeviceSupportsFreezing(device_options):
                raise ValueError(
                    'freeze_checkpoint cannot be used with device ' +
                    device_options.device)
            if freeze_checkpoint:
                tf.logging.info('Freezing graph from checkpoint: %s',
                                freeze_checkpoint)
                graph_def = _FreezeGraphFromCheckpoint(graph, saver,
                                                       freeze_checkpoint,
                                                       output_op_names)
            else:
                # Only freeze_defaults can be set here (outer condition).
                tf.logging.info('Default initializing graph and freezing.')
                graph_def = _FreezeDefaults(graph, output_op_names)
        else:
            output_op_names = GetOutputOpNames(graph, inference_graph_proto)

            # Prune the graph to just the parts we need.
            # To support restoring, we have to not prune out the restore node.
            output_op_names.append('init_all_tables')
            output_op_names.append('init_all_variables')
            output_op_names.append('save/control_dependency')
            output_op_names.append('save/restore_all')
            if IsTpu(device_options) and device_options.gen_init_op:
                output_op_names.append('tpu_init_op')
            graph_def = graph.as_graph_def()
            tf.logging.info('Pruning graph to output ops: %r', output_op_names)
            graph_def = tf.graph_util.extract_sub_graph(
                graph_def, output_op_names)

        if not device_options.retain_device_placement:
            # Clear the device so that the runtime can choose.
            tf.logging.info('Clearing device placement for: %s',
                            device_options.device)
            for node in graph_def.node:
                node.ClearField('device')
            for function in graph_def.library.function:
                for node_def in function.node_def:
                    node_def.ClearField('device')

        inference_graph_proto.graph_def.CopyFrom(graph_def)

        if export_path:
            with tf.io.gfile.GFile(export_path, 'w') as f:
                f.write(text_format.MessageToString(inference_graph_proto))
        return inference_graph_proto