Beispiel #1
0
 def Inference(self):
   """Builds an InferenceGraph proto with a single pass-through subgraph.

   Returns:
     An `inference_graph_pb2.InferenceGraph` whose 'default' subgraph has one
     feed ('feed1', a float32 placeholder of shape [1]) and one fetch
     ('fetch1', the identity of that placeholder), both recorded by tensor
     name.
   """
   with tf.name_scope('inference'):
     placeholder = tf.placeholder(
         name='feed1_node', dtype=tf.float32, shape=[1])
     passthrough = tf.identity(placeholder, name='fetch1_node')

   # The proto only stores tensor names, so it is filled in after the ops
   # have been created.
   graph_proto = inference_graph_pb2.InferenceGraph()
   default_subgraph = graph_proto.subgraphs['default']
   default_subgraph.feeds['feed1'] = placeholder.name
   default_subgraph.fetches['fetch1'] = passthrough.name
   return graph_proto
Beispiel #2
0
 def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                               unused_step_ids, states,
                               unused_num_hyps_per_beam):
   """Stub beam-search step: passes attention through, emits random log-probs.

   Only `states` is read. `tgt_batch_size` and `vocab_size` are taken from
   the enclosing scope.

   Returns:
     A pair of (NestedMap with 'atten_probs' and 'log_probs', `states`), with
     `states` returned unchanged.
   """
   attention = tf.identity(states.atten_probs)
   # Fixed op seed for the random scores.
   random_scores = tf.random.normal(
       [tgt_batch_size, vocab_size], seed=8273747)
   step_results = py_utils.NestedMap({
       'atten_probs': attention,
       'log_probs': random_scores,
   })
   return step_results, states
 def __AddVariable(self, var):
   """Registers `var` under its name, unless that name is already active.

   On first registration the variable is stored in `_private_vars`, an
   identity tensor of it is stored in `_private_theta`, and the name is added
   to `_activated_var_names`. A repeated name is only logged.

   Args:
     var: The variable to register; its key is `self.get_name(var)`.
   """
   var_name = self.get_name(var)
   if var_name not in self._activated_var_names:
     tf.logging.info("Adding variable {}".format(var_name))
     self._private_vars[var_name] = var
     self._private_theta[var_name] = tf.identity(var)
     self._activated_var_names.add(var_name)
     return
   # Already registered: leave the existing entries untouched.
   tf.logging.info(
       "Warning, already activated variable with name {}".format(var_name))
Beispiel #4
0
 def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
   """Returns the op for tpu training with tail cpu computation.

   Args:
     train_loop_op: Iterable of ops/tensors produced by the training loop.
     outfeed_dequeue_op: Returned as-is as the second list element.

   Returns:
     A two-element list: identities of `train_loop_op` that depend on the
     post-training-loop op, followed by `outfeed_dequeue_op`.
   """
   # The tail computation runs once the tpu training loop step has finished.
   # It lets us operate on the variables between tpu_train_loop iterations
   # and amortize the cost of those operations. The alternative,
   # tpu.outside_compilation combined with tf.cond, is expensive.
   with tf.control_dependencies(train_loop_op):
     self._model.ConstructPostTrainingLoop()
     tail_op = self._task.post_training_loop_op
     with tf.control_dependencies([tail_op]):
       gated_loop_outputs = [tf.identity(out) for out in train_loop_op]
       return [gated_loop_outputs, outfeed_dequeue_op]
Beispiel #5
0
  def _CreateLayerVariables(self):
    """Creates (or reuses) this table's variables and registers its TPU ops.

    Table variables and load/retrieve ops are shared through
    `self._tpu_embedding_collection`, keyed by `self.table_name`, so each is
    created at most once even when several tasks/programs build this layer.
    """
    p = self.params
    collection = self._tpu_embedding_collection

    # The table variables are singletons: only build them on first use.
    known_tables = collection.table_variables
    if self.table_name not in known_tables:
      shard_params = py_utils.WeightParams(
          shape=[self._ids_per_shard, p.embedding_dim],
          init=p.params_init,
          dtype=p.dtype,
          collections=[self.__class__.__name__ + '_vars'])

      embedding_table_vars = []
      for host_id in range(p.num_tpu_hosts):
        host_device = self.GetDeviceName(host_id)
        with tf.device(host_device), py_utils.outside_all_rewrites():
          shard_name = self.GetVariableName(host_id)
          self.CreateVariable(shard_name, shard_params)
          embedding_table_vars.append(self.vars[shard_name])
          # Drop the per-shard entry from _private_vars / _private_theta; the
          # shards are added back below under the single key 'wm'.
          _RemovePrivateVar(self, shard_name)

      collection.AddTableVariables(self.table_name, embedding_table_vars)
    else:
      embedding_table_vars = known_tables[self.table_name]

    if not _ShouldUseTpu(p):
      # Skipped for TrainerTpu: the identity reference would copy the embedding
      # to the TPU for no reason. CPU jobs (eval/decode/controller) need it.
      self._private_vars['wm'] = embedding_table_vars
      self._private_theta['wm'] = [
          tf.identity(table_var) for table_var in embedding_table_vars
      ]

    # Slot variables and load/retrieve ops exist only once in the graph and are
    # shared across tasks/programs; if another program or task already created
    # them there is nothing left to do.
    if self.table_name in collection.load_ops:
      return
    assert self.table_name not in collection.retrieve_ops

    # Only trainer and controller (for checkpointing) need slot variables, and
    # only the trainer needs load/retrieve ops.
    if self.do_eval or p.is_inference:
      return
    load_ops, retrieve_ops = self.optimizer.CreateSlotVariablesAndOps(
        embedding_table_vars, self)
    collection.AddLoadRetrieveOps(self.table_name, load_ops, retrieve_ops)
Beispiel #6
0
    def CreateVariable(self, name: str, var_params: hyperparams.Params,
                       **kwargs) -> None:
        """Creates a layer variable described by `var_params`.

        The variable is stored in the layer's private vars map under `name`,
        and the value to use for it is stored in the private theta map.

        Example::

            def __init__(self, ...):    # A layer's constructor
              self.CreateVariable(
                  'weight', py_utils.WeightParams(shape=[100, 100]))

        Args:
          name: Variable name which is used as the key into vars/theta.
          var_params: `Params` used to create the variable.
          **kwargs: Keyword args passed to `.py_utils.CreateVariable`.
        """
        p = self.params
        kwargs.setdefault('default_seed', p.random_seed)

        if p.device_mesh is not None:
            # A variable with more than one dimension larger than 1 is
            # expected to carry a split mapping when a device mesh is set.
            num_large_dims = sum(1 for dim in var_params.shape if dim > 1)
            if (num_large_dims > 1
                    and var_params.tensor_split_dims_mapping is None):
                tf.logging.warning(
                    'tensor_split_dims_mapping missing for %s.%s: shape=%s',
                    self.path, name, var_params.shape)

        self._CheckName(name)

        if (p.skip_lp_regularization and
                py_utils.SKIP_LP_REGULARIZATION not in var_params.collections):
            # Rebuild the params with the skip-regularization collection added.
            extended_collections = (var_params.collections +
                                    [py_utils.SKIP_LP_REGULARIZATION])
            var_params = py_utils.WeightParams(
                shape=var_params.shape,
                dtype=var_params.dtype,
                init=var_params.init,
                collections=extended_collections)
        self._var_symbolic_shape_map[name] = var_params.shape

        new_var = py_utils.CreateVariable(name, var_params, **kwargs)
        self._private_vars[name] = new_var

        # The eager trainer always uses the variable directly. In graph mode
        # on GPU (which always trains a single step per session.run()), a
        # tensor is referenced in FProp instead, to cache it on device and
        # avoid extraneous sends from reading variables from ps multiple times.
        cache_on_device = (
            not py_utils.IsEagerMode()
            and self.cluster.params.worker.gpus_per_replica > 0)
        if cache_on_device:
            with tf.device(new_var.device):
                theta_value = tf.identity(new_var, name=name)
        else:
            theta_value = new_var

        self._private_theta[name] = theta_value
Beispiel #7
0
  def __init__(self, params):
    """Initializes this Model.

    Sets up the global step variable and an identity tensor of it before the
    base-class constructor runs, then optionally creates an exponential moving
    average helper when `params.train.ema_decay` is positive.

    Args:
      params: Model params; `params.cls` must be a subclass of `BaseModel`.
    """
    assert issubclass(params.cls, BaseModel)
    step_var = py_utils.GetOrCreateGlobalStepVar()
    self._global_step_var = step_var
    self._global_step = tf.identity(step_var, name='global_step_tensor')
    super(BaseModel, self).__init__(params)

    self._ema = None
    train_params = self.params.train
    tf.logging.info('Training parameters for %s: %s', params.cls, train_params)
    if train_params.ema_decay > 0:
      # The decay must lie strictly between 0 and 1.
      assert train_params.ema_decay < 1.0
      self._ema = tf.train.ExponentialMovingAverage(
          decay=train_params.ema_decay, num_updates=self.global_step)
Beispiel #8
0
def BeamSearchDecodeOutputToDecoderTopK(decoder_outs,
                                        *,
                                        ids_to_strings_fn,
                                        tag=''):
    """Converts BeamSearchDecodeOutput to DecoderTopK.

    Besides building the result, this creates named TF nodes that eval
    pipelines look up ("top_k_decoded" and "top_k_scores", each suffixed with
    `tag`).

    Args:
      decoder_outs: a beam_search_helper.BeamSearchDecodeOutput instance.
      ids_to_strings_fn: a function of (ids, lens) -> strings, where ids has
        shape [batch, length], lens has shape [batch], and strings has shape
        [batch].
      tag: optional suffix appended to the tf.identity() node names.

    Returns:
      A DecoderTopK instance.
    """
    topk_hyps = decoder_outs.topk_hyps
    topk_ids = decoder_outs.topk_ids
    topk_lens = tf.identity(
        decoder_outs.topk_lens, name='TopKLabelLengths' + tag)
    topk_scores = decoder_outs.topk_scores
    topk_decoded = decoder_outs.topk_decoded

    if topk_ids is not None:
        topk_ids = tf.identity(topk_ids, name='TopKLabelIds' + tag)
        # With the assumption that ids[-1] is always EOS token.
        # TODO(b/195027707): remove EOS token in better way.
        lens_without_eos = topk_lens - 1
        decoded_strings = ids_to_strings_fn(topk_ids, lens_without_eos)
        decoded_strings = tf.identity(
            decoded_strings, name='top_k_decoded%s' % tag)
        topk_decoded = tf.reshape(decoded_strings, tf.shape(topk_scores))

    if topk_scores is not None and topk_hyps is not None:
        # The named node carries the scores in the shape of the lengths; the
        # returned scores are then reshaped to match the hyps.
        flat_scores = tf.reshape(topk_scores, tf.shape(topk_lens))
        named_scores = tf.identity(flat_scores, name='top_k_scores%s' % tag)
        topk_scores = tf.reshape(named_scores, tf.shape(topk_hyps))

    return DecoderTopK(topk_hyps, topk_ids, topk_lens, topk_scores,
                       topk_decoded)
 def Inference(self):
   """Returns inference subgraphs as a dict of (fetches, feeds) pairs.

   Returns:
     A dict with a 'default' subgraph (one placeholder feed and its identity
     as fetch, plus the fetch's op) and an empty 'unused' subgraph.

   Raises:
     NotImplementedError: if `py_utils.use_tpu()` is true.
   """
   if py_utils.use_tpu():
     raise NotImplementedError('TPU is not supported.')

   with tf.name_scope('inference'):
     placeholder = tf.placeholder(
         name='feed1_node', dtype=tf.float32, shape=[1])
     passthrough = tf.identity(placeholder, name='fetch1_node')

   default_fetches = py_utils.NestedMap({
       'fetch1': passthrough,
       'fetch_op': passthrough.op,  # Tests that ops are supported.
   })
   default_feeds = py_utils.NestedMap({
       'feed1': placeholder,
   })
   return {
       'default': (default_fetches, default_feeds),
       'unused': (py_utils.NestedMap({}), py_utils.NestedMap({})),
   }
Beispiel #10
0
    def _CreateVariableInternal(self, name, meta):
        """Immediately creates the variable described by `meta`.

        DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should
        call self.CreateVariable() to create variables.

        Args:
          name: The variable name.
          meta: A CreateVariableMeta describing the variable to be created.
        """
        meta.kwargs.setdefault('default_seed', self.params.random_seed)
        new_var = py_utils.CreateVariable(name, meta.var_params,
                                          **meta.kwargs)
        self._private_vars[name] = new_var

        # Resource variables are used directly; any other variable is read
        # through an identity op placed on the variable's own device.
        theta_value = new_var
        if not resource_variable_ops.is_resource_variable(new_var):
            with tf.device(new_var.device):
                theta_value = tf.identity(new_var)

        theta_fn = meta.theta_fn
        if theta_fn is not None:
            theta_value = theta_fn(theta_value)
        self._private_theta[name] = theta_value
Beispiel #11
0
  def _CreateVariableInternal(self, name: str,
                              meta: CreateVariableMeta) -> None:
    """Immediately creates the variable described by `meta`.

    DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should
    call self.CreateVariable() to create variables.

    Args:
      name: The variable name.
      meta: A CreateVariableMeta describing the variable to be created.
    """
    var_params = meta.var_params
    meta.kwargs.setdefault('default_seed', self.params.random_seed)
    new_var = py_utils.CreateVariable(name, var_params, **meta.kwargs)
    self._private_vars[name] = new_var

    # By default the resource variable is passed directly into the training
    # loop.
    theta_value = new_var
    if self.cluster.params.worker.gpus_per_replica > 0:
      # On GPU (which always trains a single step per session.run()), reference
      # a tensor in FProp to cache it on device and avoid extraneous sends from
      # reading variables from ps multiple times.
      with tf.device(new_var.device):
        theta_value = tf.identity(new_var)

    # Due to b/174956514, we have to annotate the use of the variable once,
    # otherwise, the sharding annotation on the var will be ignored.
    # TODO(yonghui): Get rid of this once b/174956514 is fixed.
    mesh = var_params.device_mesh
    if mesh is not None:
      split_mapping = var_params.tensor_split_dims_mapping
      if new_var.shape.rank == len(split_mapping):
        theta_value = gshard_utils.MeshSplit(
            theta_value, mesh, split_mapping, use_sharding_op=True)

    # A theta_fn is recorded here rather than applied to the value.
    if meta.theta_fn is not None:
      self._private_theta_fn[name] = meta.theta_fn

    self._private_theta[name] = theta_value
Beispiel #12
0
def _WrapNonLingvoVars(dest_layer: base_layer.BaseLayer,
                       variables: Collection[tf.Variable],
                       trainable_variables: Collection[tf.Variable] = ()):
    """Adds variables to the given lingvo layer and appropriate graph collections.

    Variables created outside of lingvo are wrapped so that lingvo's trainer
    and checkpointer can see them:

      - every entry of `variables` becomes trackable via `dest_layer.vars`
        (and gets an identity tensor in the layer's theta);
      - every entry of `variables` is asserted to already be in the
        `tf.global_variables()` graph collection;
      - entries of `trainable_variables` that are missing from the
        `tf.trainable_variables()` graph collection are added to it, so they
        are visible to the learner (i.e. can be trained).

    Args:
      dest_layer: Lingvo layer to add the `variables` to.
      variables: The non-lingvo variables to wrap.
      trainable_variables: The subset of `variables` to ensure are trainable.
    """
    known_globals = set(tf.global_variables())
    for var in variables:
        assert var in known_globals
        # Key by the variable name without its ':<index>' suffix.
        key = var.name.partition(':')[0]
        # pylint: disable=protected-access
        dest_layer._private_vars[key] = var
        with tf.device(var.device):
            dest_layer._private_theta[key] = tf.identity(var)
        # pylint: enable=protected-access

    known_trainables = set(tf.trainable_variables())
    for var in trainable_variables:
        if var in known_trainables:
            continue
        tf.logging.warning(
            'Wrapped var %s not in trainable collection; adding it.',
            var.name)
        tf.compat.v1.add_to_collection(
            tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, var)
Beispiel #13
0
    def _CreateLayerVariables(self):
        """Creates one embedding table variable per TPU host and its ops.

        The per-host variables are registered with
        `self._tpu_embedding_collection` under `self.table_name`. Outside of
        TPU they are also exposed as the layer variable 'wm'. Unless this
        layer is in eval mode, slot variables and load/retrieve ops are
        created as well.
        """
        p = self.params
        shard_params = py_utils.WeightParams(
            shape=[self._ids_per_shard, p.embedding_dim],
            init=p.params_init,
            dtype=p.dtype,
            collections=[self.__class__.__name__ + '_vars'])

        embedding_table_vars = []
        for host_id in range(p.num_tpu_hosts):
            host_device = self.GetDeviceName(host_id)
            with tf.device(host_device), py_utils.outside_all_rewrites():
                shard_name = self.GetVariableName(host_id)
                self.CreateVariable(shard_name, shard_params)
                embedding_table_vars.append(self.vars[shard_name])
                # Drop the per-shard entries; the shards are added back below
                # under the single key 'wm'.
                del self._private_vars[shard_name]
                del self._private_theta[shard_name]

        self._tpu_embedding_collection.AddTableVariables(
            self.table_name, embedding_table_vars)

        if not py_utils.use_tpu():
            # Skipped for TrainerTpu: the identity reference would copy the
            # embedding to the TPU for no reason. CPU jobs
            # (eval/decode/controller) need it.
            self._private_vars['wm'] = embedding_table_vars
            self._private_theta['wm'] = [
                tf.identity(table_var) for table_var in embedding_table_vars
            ]

        # Only trainer and controller need slot variables and load/retrieve
        # ops.
        if self.do_eval:
            return
        load_ops, retrieve_ops = self.optimizer.CreateSlotVariablesAndOps(
            embedding_table_vars, self)
        self._load_op_list = load_ops
        self._retrieve_op_list = retrieve_ops
  def Export(cls,
             model_cfg,
             model_task_name=None,
             device_options=InferenceDeviceOptions(
                 device='',
                 retain_device_placement=False,
                 var_options=None,
                 gen_init_op=True,
                 dtype_override=None),
             freeze_checkpoint=None,
             freeze_defaults=False,
             export_path=None,
             subgraph_filter=None,
             random_seed=None,
             disable_packed_input=True):
    """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

    The model is instantiated in a fresh `tf.Graph`. The resulting GraphDef is
    either frozen (when `freeze_checkpoint` or `freeze_defaults` is set) or
    pruned to the inference outputs plus the init and saver-restore ops.

    Note that `model_cfg` is modified in place: `random_seed` and
    `is_inference` are overwritten, and packed input may be switched off on
    its task params.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default initializes the graph and freeze. Useful for
        early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A list of subgraph names. If not None or empty, export
        only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input for inference writing purposes.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
    assert issubclass(model_cfg.cls, base_model.BaseModel)

    # Disable assertions unless user explicitly enables it.
    if FLAGS['enable_asserts'].using_default_value:
      FLAGS.enable_asserts = False

    # TODO(laurenzo): Work out how much we need to specify here in terms of
    # cluster configuration.
    cls._SetClusterParams(model_cfg.cluster, device_options)

    # Configure the model.
    model_cfg.random_seed = random_seed
    model_cfg.is_inference = True

    if disable_packed_input:

      # Turns off `packed_input` on the task's encoder and decoder params,
      # for whichever of those params exist.
      def _DisablePackedInput(task):
        if (_ParamExists(task, 'encoder') and
            _ParamExists(task.encoder, 'packed_input')):
          task.encoder.packed_input = False
        if (_ParamExists(task, 'decoder') and
            _ParamExists(task.decoder, 'packed_input')):
          task.decoder.packed_input = False

      # Multi-task models carry one params object per task; single-task models
      # expose it as `model_cfg.task`.
      if issubclass(model_cfg.cls, base_model.MultiTaskModel):
        for _, task_param in model_cfg.task_params.IterParams():
          _DisablePackedInput(task_param)
      else:
        _DisablePackedInput(model_cfg.task)

    tf.logging.info('Model %s. Params: %s', model_cfg.name,
                         model_cfg.ToText())

    # Instantiate the graph.
    graph = tf.Graph()
    with graph.as_default():
      tf.random.set_seed(random_seed)
      cluster = model_cfg.cluster.Instantiate()
      device = cluster.GetPlacer()
      tpu_const_scope = _DummyScope()
      if (IsTpu(device_options) and
          device_options.var_options == 'AS_CONSTANTS'):
        # Do not specify devices for variables if we are marking them as
        # constants.
        device = ''
        tpu_const_scope = ConstGuaranteeScope()

      with cluster, tf.device(device), tpu_const_scope:

        bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
            device_options)

        if bfloat16_override:
          py_utils.UpdateDtype(model_cfg, tf.bfloat16)
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        # Hard-code TPU-related flags prior to instantiating model.
        # The previous values are saved here and restored in the `finally`
        # block below.
        old_enable_asserts = FLAGS.enable_asserts
        old_xla_device = FLAGS.xla_device
        if IsTpu(device_options):
          FLAGS.enable_asserts = False
          FLAGS.xla_device = 'tpu'

        # Ensure the global_step variable is created.
        global_step_var = py_utils.GetOrCreateGlobalStepVar()
        global_step = tf.identity(global_step_var, name='global_step_tensor')

        with py_utils.GlobalStepContext(global_step):
          try:
            mdl = model_cfg.Instantiate()
            # Without EMA, restore all global variables; with EMA, restore
            # through the mapping provided by the EMA object.
            variables_to_restore = (
                _MakeVariableDictionary(tf.global_variables()) if not mdl.ema
                else mdl.ema.variables_to_restore(mdl.variables_for_ema))

            if bfloat16_override:
              saver_var_spec = (
                  bfloat16_variables
                  .get_saver_spec_for_variables_with_bf16_overrides(
                      variables_to_restore))
            else:
              saver_var_spec = variables_to_restore

            saver = tf.train.Saver(saver_var_spec)
            # The init ops are created for their names only; the non-frozen
            # path below keeps them in the pruned graph by name.
            tf.variables_initializer(
                tf.global_variables(), name='init_all_variables')
            if IsTpu(device_options) and device_options.gen_init_op:
              tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

            model_task = mdl.GetTask(model_task_name)

            inference_graph_proto = inference_graph_pb2.InferenceGraph()
            # `Inference()` may return either a proto or a dict of subgraphs;
            # a dict is converted to the proto form first.
            subgraphs_proto = model_task.Inference()
            if isinstance(subgraphs_proto, dict):
              subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
            # An empty or None `subgraph_filter` keeps every subgraph.
            for name, subgraph in subgraphs_proto.subgraphs.items():
              if not subgraph_filter or name in subgraph_filter:
                inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

            # Add a table init op and global variable init op to the graph.
            # Tables can be declared anywhere in the graph, so this op has to be
            # added last.
            tf.tables_initializer(name='init_all_tables')
          finally:
            # Reset TPU-related flags after model instantiation.
            FLAGS.enable_asserts = old_enable_asserts
            FLAGS.xla_device = old_xla_device

    tf.logging.info('Graph contains ops: %r',
                         [op.name for op in graph.get_operations()])

    inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

    # Freezing.
    if freeze_defaults or freeze_checkpoint:
      output_op_names = GetOutputOpNames(
          graph, inference_graph_proto, preserve_colocation_nodes=False)
      # NOTE(review): this raises when `_DeviceSupportsFreezing` returns True,
      # which reads inverted relative to the helper's name -- confirm the
      # helper's actual semantics. The message also mentions only
      # `freeze_checkpoint` although `freeze_defaults` reaches this check too.
      if cls._DeviceSupportsFreezing(device_options):
        raise ValueError('freeze_checkpoint cannot be used with device ' +
                         device_options.device)
      if freeze_checkpoint:
        tf.logging.info('Freezing graph from checkpoint: %s',
                             freeze_checkpoint)
        graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                               output_op_names)
      elif freeze_defaults:
        tf.logging.info('Default initializing graph and freezing.')
        graph_def = _FreezeDefaults(graph, output_op_names)
    else:
      output_op_names = GetOutputOpNames(graph, inference_graph_proto)

      # Prune the graph to just the parts we need.
      # To support restoring, we have to not prune out the restore node.
      output_op_names.append('init_all_tables')
      output_op_names.append('init_all_variables')
      output_op_names.append('save/control_dependency')
      output_op_names.append('save/restore_all')
      if IsTpu(device_options) and device_options.gen_init_op:
        output_op_names.append('tpu_init_op')
      graph_def = graph.as_graph_def()
      tf.logging.info('Pruning graph to output ops: %r', output_op_names)
      graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

    if not device_options.retain_device_placement:
      # Clear the device so that the runtime can choose.
      tf.logging.info('Clearing device placement for: %s',
                           device_options.device)
      for node in graph_def.node:
        node.ClearField('device')
      # Nodes inside library functions carry their own device fields.
      for function in graph_def.library.function:
        for node_def in function.node_def:
          node_def.ClearField('device')

    inference_graph_proto.graph_def.CopyFrom(graph_def)

    if export_path:
      # The proto is written in text (ASCII) format.
      with tf.io.gfile.GFile(export_path, 'w') as f:
        f.write(text_format.MessageToString(inference_graph_proto))
    return inference_graph_proto
Beispiel #15
0
    def FProp(self, theta, input_tensor):
        """Passes `input_tensor` through unchanged, with a GradDrop backward.

        The forward computation is an identity. The backward function
        `_Gradient` combines the per-loss gradients from `self._losses` with
        respect to the output tensor, using sign-based random masking.

        Args:
          theta: Not used by this method.
          input_tensor: The tensor to pass through.

        Returns:
          The output tensor, which is also stored in `self._output_tensor`.

        Raises:
          ValueError: if FProp was already called on this layer, i.e.
            `self._output_tensor` is already set.
        """
        p = self.params

        if self._output_tensor is not None:
            raise ValueError('FProp was already called.')

        # Backward function handed to `py_utils.CallDefun` below. It reads
        # `inputs` and `original_grad`; the middle argument is ignored.
        def _Gradient(inputs, _, original_grad):

            # Compute the gradients for each loss w.r.t. the inputs.
            # TODO(jngiam): Look into whether TF dedups this computation.
            # `self._losses` is iterated as (loss, leak_ratio) pairs.
            per_loss_grads = []
            for loss, _ in self._losses:
                per_loss_grad = tf.gradients(loss, self._output_tensor)[0]
                if per_loss_grad is None:
                    tf.logging.warning(
                        'Loss %s did not result in a gradient during '
                        'GradDrop computation.', loss)
                else:
                    per_loss_grads.append(per_loss_grad)

            if not per_loss_grads:
                raise ValueError('No valid gradients for GradDrop.')

            # Multiply the gradients with the inputs.
            grads = per_loss_grads
            if p.use_input_sign_only:
                # Entries with |input| <= epsilon get 1.0 added before taking
                # the absolute value for the denominator.
                input_abs = tf.abs(
                    tf.cast(tf.abs(inputs) <= p.epsilon, tf.float32) + inputs)
                grads = [grad * ((inputs) / (input_abs)) for grad in grads]
            else:
                grads = [grad * inputs for grad in grads]

            # Sum gradient over batch, assuming that batch is always on dim 0.
            if p.marginalize_batch_dim:
                grads = [
                    tf.reduce_sum(grad, axis=0, keepdims=True)
                    for grad in grads
                ]

            # First discretize all gradients into their sign values.
            grad_sign_positive = [
                tf.cast(grad > 0.0, tf.float32) for grad in grads
            ]
            grad_sign_negative = [
                tf.cast(grad < 0.0, tf.float32) for grad in grads
            ]

            # Calculate the probability of positive gradients based on equation (1)
            # in the GradDrop paper.
            grad_abs_sum = tf.add_n([tf.abs(grad) for grad in grads])
            prob_pos = (tf.add_n(grads) / (2. * grad_abs_sum + p.epsilon))
            # Implementation of different scales for the keep function. Larger
            # scales result in steeper keep functions.
            prob_pos *= p.keep_prob_function_scale

            # NOTE(review): any other `keep_prob_function` value leaves
            # `prob_pos` as the scaled ratio computed above.
            if p.keep_prob_function == 'sigmoid':
                # Standard sigmoid has derivative of 0.25 at 0 so the factor of 4.0
                # allows the function scale in sigmoid to be compatible with the
                # function scale in the linear case.
                prob_pos = tf.sigmoid(4.0 * prob_pos)
            elif p.keep_prob_function == 'linear':
                prob_pos += 0.5

            # The main, default mode of GradDrop. Only gradients of one sign are kept,
            # and which sign is calculated via equation (1) of the main paper.
            # After this step `prob_pos` holds +0.5 where the positive sign is
            # kept and -0.5 where the negative sign is kept.
            prob_pos = tf.cast(prob_pos >= tf.random.uniform(prob_pos.shape),
                               tf.float32) - 0.5
            # (gsp - gsn) is the gradient's sign (+1, 0 or -1); the mask is
            # true where that sign agrees with the chosen one, or is zero.
            grad_masks = [
                (gsp - gsn) * prob_pos >= 0
                for (gsn, gsp) in zip(grad_sign_negative, grad_sign_positive)
            ]

            # This diag value gives us the percentage of grads which are kept.
            gradmask_diag = [tf.cast(gm, tf.float32) for gm in grad_masks]
            diag = tf.reduce_mean(tf.add_n(gradmask_diag) / len(grad_masks))
            summary_utils.scalar('average_grad_mask', diag)
            leak_ratios = [leak_ratio for _, leak_ratio in self._losses]
            # Masked-out entries still pass a `leak` fraction of their
            # gradient; kept entries pass the full gradient.
            # NOTE(review): `leak_ratios` covers every loss, while
            # `per_loss_grads` omits losses whose gradient was None, so `zip`
            # could pair ratios with the wrong gradients in that case --
            # confirm.
            transformed_per_loss_grads = [
                grad * (leak + (1.0 - leak) * tf.cast(grad_mask, tf.float32))
                for (leak, grad,
                     grad_mask) in zip(leak_ratios, per_loss_grads, grad_masks)
            ]

            transformed_grad = tf.cast(tf.add_n(transformed_per_loss_grads),
                                       original_grad.dtype)

            if not p.keep_gradnorm_constant:
                return transformed_grad

            # Rescale so the result has (approximately) the same L2 norm as
            # the incoming gradient.
            transformed_grad_norm = tf.sqrt(tf.reduce_sum(transformed_grad**2))
            original_grad_norm = tf.sqrt(tf.reduce_sum(original_grad**2))
            return transformed_grad * original_grad_norm / (
                transformed_grad_norm + p.epsilon)

        output_tensor = py_utils.CallDefun(tf.identity, input_tensor,
                                           _Gradient)
        # `_Gradient` differentiates the losses w.r.t. this stored tensor.
        self._output_tensor = tf.identity(output_tensor)
        return self._output_tensor
Beispiel #16
0
    def try_apply_dense(self, grad, var):
        """Builds the Adafactor update math for `var` without applying it.

        No variable or slot is assigned here; the assign calls are left
        commented out. Each intermediate quantity is recorded together with a
        finiteness check.

        Args:
          grad: Gradient for `var`; must not be None. It is cast to
            `var.dtype`.
          var: The variable whose update is computed.

        Returns:
          A pair `(is_finite_checks, stats)`. `stats` maps a name (e.g.
          'grad_squared', 'new_vr', 'subtrahend') to the computed tensor, and
          `is_finite_checks` holds one all-finite boolean reduction per
          recorded tensor, in the same order.
        """
        assert grad is not None

        cond = tf.constant(True)
        is_finite_checks = []
        stats = {}

        grad_dtype = var.dtype  # TODO(lepikhin): add to params
        grad = tf.cast(grad, grad_dtype)
        # Factored variables use row/column accumulators ('vr', 'vc'); others
        # use a single accumulator 'v'.
        factored_dims = self._factored_dims(var.shape.as_list())
        if factored_dims:
            vr = self.get_slot(var, 'vr')
            vc = self.get_slot(var, 'vc')
        else:
            v = self.get_slot(var, 'v')
        if self._beta1:
            m = self.get_slot(var, 'm')

        # Records `x` under key `k` and appends its finiteness check.
        # NOTE(review): `c` is returned unchanged, so `cond` stays the initial
        # constant and is never read afterwards.
        def _Upd(c, k, x):
            stats[k] = x
            is_finite_checks.append(tf.reduce_all(tf.math.is_finite(x)))
            return c

        # `var.name[:-2]` drops the last two characters of the variable name
        # (presumably the ':0' suffix).
        with tf.variable_scope(var.name[:-2] + '/Adafactor'):
            grad_squared = tf.math.square(grad) + tf.cast(
                self._epsilon1, grad_dtype)
            cond = _Upd(cond, 'grad_squared', grad_squared)  # 0 (factored)
            decay_rate = tf.cast(self._decay_rate, var.dtype)
            old_val = tf.identity(
                var)  # TODO(lepikhin): introduce gradient dtype
            # NOTE(review): with this assert in place the `else` branch below
            # is unreachable (unless asserts are disabled).
            assert self._multiply_by_parameter_scale
            lr = GetLrValue(self._learning_rate)
            if self._multiply_by_parameter_scale:
                parameter_scale = self._parameter_scale(old_val)
                cond = _Upd(cond, 'parameter_scale',
                            parameter_scale)  # 1 (factored)
                # NOTE(review): `_parameter_scale(old_val)` is evaluated a
                # second time here instead of reusing `parameter_scale`.
                update_scale = self._parameter_scale(old_val) * tf.cast(
                    lr, grad_dtype)

            else:
                update_scale = lr
            mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
            update_scale = tf.cast(update_scale, grad_dtype)
            if factored_dims:
                d0, d1 = factored_dims
                vr_axis, vc_axis = d0, d1
                grad_squared_row_mean = tf.reduce_mean(grad_squared,
                                                       axis=vr_axis)
                grad_squared_col_mean = tf.reduce_mean(grad_squared,
                                                       axis=vc_axis)
                # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                cond = _Upd(cond, 'new_vr', new_vr)  # 2 (factored)
                cond = _Upd(cond, 'new_vc', new_vc)  # 3 (factored)
                # vr_update = _Wrap(tf.assign, vr, new_vr)
                # vc_update = _Wrap(tf.assign, vc, new_vc)
                # updates.extend([vr_update, vc_update])
                long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
                r_factor = tf.math.rsqrt(new_vr / long_term_mean)
                c_factor = tf.math.rsqrt(new_vc)
                # Re-expand the row/column factors along the reduced axes so
                # their product can scale `grad`.
                mult = tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
                    c_factor, vc_axis)
                cond = _Upd(cond, 'mult', mult)  # 4 (factored)
                x = grad * mult
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                cond = _Upd(cond, 'new_v', new_v)
                # v_update = _Wrap(tf.assign, v, new_v)
                # updates.append(v_update)
                x = grad * tf.math.rsqrt(new_v)

            # NOTE(review): the assert makes the `is not None` test below
            # always true.
            assert self._clipping_threshold is not None

            if self._clipping_threshold is not None:
                # Scale `x` down when ReduceRms(x) exceeds the clipping
                # threshold; the denominator is never below 1.
                clipping_denom = tf.maximum(
                    tf.constant(1.0, grad_dtype),
                    py_utils.ReduceRms(x) /
                    tf.constant(self._clipping_threshold, grad_dtype))
                x /= clipping_denom
            cond = _Upd(cond, 'x', x)
            subtrahend = x * update_scale
            if self._beta1:
                # With momentum, the subtrahend becomes the updated moving
                # average `new_m`.
                new_m = (m * tf.constant(self._beta1, dtype=grad_dtype) +
                         subtrahend *
                         tf.constant(1.0 - self._beta1, dtype=grad_dtype))
                subtrahend = new_m
                cond = _Upd(cond, 'new_m', new_m)
                # updates.append(_Wrap(tf.assign, m, new_m))

            # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
            #  for the case of bfloat16 activations, so as to avoid repeatedly
            #  rounding the slice value, which results in poor quality.
            cond = _Upd(cond, 'subtrahend', subtrahend)  # 5 (factored)

            # var_update = _Wrap(tf.assign_sub, var, subtrahend)
            # updates.append(var_update)

            return is_finite_checks, stats
Beispiel #17
0
    def _resource_apply_dense(self, grad, var):
        """Builds the slot and variable update ops for one dense variable.

        Args:
          grad: Gradient for `var`, or None. It is cast to `var.dtype` before
            use.
          var: The variable to update. Its slots ('vr'/'vc' or 'v', plus 'm'
            when `self._beta1` is truthy) are fetched with `self.get_slot`.

        Returns:
          An empty list when `grad` is None; otherwise a `tf.group` of the
          slot updates followed by the `assign_sub` on `var`.
        """
        if grad is None:
            tf.logging.warning('Gradient is None for variable %s' % var.name)
            return []

        grad_dtype = var.dtype  # TODO(lepikhin): add to params
        grad = tf.cast(grad, grad_dtype)
        # A truthy result selects the factored (vr, vc) accumulators; otherwise
        # a single accumulator v is used.
        factored_dims = self._factored_dims(var.shape.as_list())
        if factored_dims:
            vr = self.get_slot(var, 'vr')
            vc = self.get_slot(var, 'vc')
        else:
            v = self.get_slot(var, 'v')
        if self._beta1:
            m = self.get_slot(var, 'm')

        # Running "everything seen so far is finite" flag. It is rebound after
        # each intermediate below and read by _Wrap.
        cond = tf.constant(True)

        def _Upd(c, x):
            """ANDs `c` with a finiteness check on `x` (no-op if disabled)."""
            if not self._cond_is_finite:
                return c
            c = tf.math.logical_and(c, tf.reduce_all(tf.math.is_finite(x)))
            # NOTE(review): is_finite should already reject inf, so this second
            # check looks redundant -- confirm before removing.
            c = tf.math.logical_and(
                c, tf.reduce_all(tf.math.logical_not(tf.math.is_inf(x))))
            return c

        def _Wrap(fn, x, y):
            """Applies fn(x, y), gated on `cond` when finiteness checks are on."""
            if not self._cond_is_finite:
                return fn(x, y)
            # `cond` is looked up when _Wrap is called, so each update is gated
            # only by the checks accumulated up to that point; the false branch
            # returns x unchanged.
            return tf.cond(cond, lambda: fn(x, y), lambda: x)

        # var.name[:-2] drops the last two characters of the variable name
        # (presumably the ':0' output suffix -- confirm).
        with tf.variable_scope(var.name[:-2] + '/Adafactor'):
            grad_squared = tf.math.square(grad) + tf.cast(
                self._epsilon1, grad_dtype)
            cond = _Upd(cond, grad_squared)
            decay_rate = tf.cast(self._decay_rate, var.dtype)
            old_val = tf.identity(
                var)  # TODO(lepikhin): introduce gradient dtype
            lr = GetLrValue(self._learning_rate)
            # Step size: the learning rate, optionally multiplied by
            # self._parameter_scale of the current variable value.
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(old_val) * tf.cast(
                    lr, grad_dtype)
            else:
                update_scale = lr
            mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
            update_scale = tf.cast(update_scale, grad_dtype)
            updates = []
            if factored_dims:
                d0, d1 = factored_dims
                vr_axis, vc_axis = d0, d1
                # Means of the squared gradient over each of the two factored
                # axes feed the two accumulators.
                grad_squared_row_mean = tf.reduce_mean(grad_squared,
                                                       axis=vr_axis)
                grad_squared_col_mean = tf.reduce_mean(grad_squared,
                                                       axis=vc_axis)
                # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                cond = _Upd(cond, new_vr)
                cond = _Upd(cond, new_vc)
                vr_update = _Wrap(tf.assign, vr, new_vr)
                vc_update = _Wrap(tf.assign, vc, new_vc)
                updates.extend([vr_update, vc_update])
                # Scale the gradient by rsqrt(new_vr / mean(new_vr over the
                # last axis)) and rsqrt(new_vc), each re-expanded along the
                # axis that was reduced away.
                long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
                r_factor = tf.math.rsqrt(new_vr / long_term_mean)
                c_factor = tf.math.rsqrt(new_vc)
                x = grad * tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
                    c_factor, vc_axis)
            else:
                # Unfactored case: one accumulator with the same shape as the
                # squared gradient.
                new_v = v * decay_rate + grad_squared * mixing_rate
                cond = _Upd(cond, new_v)
                v_update = _Wrap(tf.assign, v, new_v)
                updates.append(v_update)
                x = grad * tf.math.rsqrt(new_v)
            if self._clipping_threshold is not None:
                # Divide x by max(1, ReduceRms(x) / threshold); the max keeps
                # the divisor from dropping below 1.
                clipping_denom = tf.maximum(
                    tf.constant(1.0, grad_dtype),
                    py_utils.ReduceRms(x) /
                    tf.constant(self._clipping_threshold, grad_dtype))
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                # Momentum: new_m = beta1 * m + (1 - beta1) * subtrahend, and
                # new_m is what gets subtracted from the variable.
                new_m = (m * tf.constant(self._beta1, dtype=grad_dtype) +
                         subtrahend *
                         tf.constant(1.0 - self._beta1, dtype=grad_dtype))
                subtrahend = new_m
                cond = _Upd(cond, new_m)
                updates.append(_Wrap(tf.assign, m, new_m))
            # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
            #  for the case of bfloat16 activations, so as to avoid repeatedly
            #  rounding the slice value, which results in poor quality.
            cond = _Upd(cond, subtrahend)
            var_update = _Wrap(tf.assign_sub, var, subtrahend)
            updates.append(var_update)
            return tf.group(*updates)
Beispiel #18
0
 def InstantiateVariables(self):
   """Instantiates variables inside a global-step context.

   Wraps `self._global_step_var` in a `tf.identity` named
   'global_step_tensor' and runs the parent implementation while
   `py_utils.GlobalStepContext` is active for that tensor.
   """
   step_tensor = tf.identity(self._global_step_var, name='global_step_tensor')
   with py_utils.GlobalStepContext(step_tensor):
     super().InstantiateVariables()
Beispiel #19
0
    def __init__(self,
                 learning_rate,
                 momentum=0.0,
                 initial_accumulator_value=0.0,
                 start_preconditioning_steps=1000,
                 statistics_computation_frequency=1,
                 matrix_epsilon=1e-6,
                 synchronous_preconditioning=False,
                 second_moment_averaging=1.0,
                 fallback_to_diagonal_dim=4096,
                 max_any_dim=6656,
                 block_size=4096,
                 block_partition_threshold_size=1000000,
                 global_step=None,
                 exponent_multiplier=1.0,
                 name="DistributedShampoo"):
        """Construct a DistributedShampoo optimizer.

        Args:
          learning_rate: A `Tensor` or a floating point value. The learning
            rate.
          momentum: A `Tensor` or a floating point value. Momentum is not
            applied to sparse updates.
          initial_accumulator_value: A floating point value.
          start_preconditioning_steps: An int32 value which indicates when to
            start preconditioning.
          statistics_computation_frequency: An int32 step value which indicates
            how often to compute statistics for preconditioning.
          matrix_epsilon: An epsilon regularizer to make the matrices positive
            definite.
          synchronous_preconditioning: Whether to run preconditioning
            synchronously.
          second_moment_averaging: 1.0 means sum of gradient squares, while
            less than 1.0 switches to RMSProp style exponential moving averages
            of the second moments.
          fallback_to_diagonal_dim: Fall back to the diagonal version of AFMA
            if any of the dimensions is larger than fallback_to_diagonal_dim.
          max_any_dim: If the maximum value for any dimension is greater than
            this value we skip preconditioning and fall back to the diagonal.
          block_size: Dimension of the partitioned tensors.
          block_partition_threshold_size: Partitions dimensions beyond this
            size.
          global_step: Global step for training. When None, the default global
            step from `tf.train.get_or_create_global_step()` is used.
          exponent_multiplier: A multiplier `e` for the exponent for the
            inverse calculation. e * -1/(2*rank). Only applies when calculating
            inverses through svd.
          name: Optional name prefix for the operations created when applying
            gradients.
        """
        super().__init__(False, name)

        # Plain hyperparameters, stored as given.
        self._learning_rate = learning_rate
        self._momentum = momentum
        self._initial_accumulator_value = initial_accumulator_value
        self._matrix_epsilon = matrix_epsilon
        self._synchronous_preconditioning = synchronous_preconditioning
        self._second_moment_averaging = second_moment_averaging
        self._exponent_multiplier = exponent_multiplier
        self._start_preconditioning_steps = start_preconditioning_steps
        self._statistics_computation_frequency = statistics_computation_frequency

        # Size limits that control partitioning and the diagonal fallback.
        self._fallback_to_diagonal_dim = fallback_to_diagonal_dim
        self._max_any_dim = max_any_dim
        self._block_size = block_size

        # NOTE: On XLA - int64 is not handled properly, hence the int32 cast.
        step_source = (global_step if global_step is not None else
                       tf.train.get_or_create_global_step())
        self._global_step = tf.cast(tf.identity(step_source), tf.int32)

        # True once the global step has reached start_preconditioning_steps.
        self._run_nondiagonal_update = tf.greater_equal(
            self._global_step, self._start_preconditioning_steps)

        # Warmup factor: (step - start) / start, clamped to [0, 1].
        # NOTE(review): start_preconditioning_steps == 0 makes the divisor
        # zero here -- confirm callers never pass 0.
        start_steps_f = tf.cast(self._start_preconditioning_steps, tf.float32)
        global_step_f = tf.cast(self._global_step, tf.float32)
        steps_past_start = global_step_f - start_steps_f
        warmup_fraction = steps_past_start / start_steps_f
        self._run_nondiagonal_update_warmup = tf.minimum(
            1.0, tf.maximum(warmup_fraction, 0.0))

        # Computes statistics every K steps.
        steps_into_period = tf.math.floormod(
            self._global_step, self._statistics_computation_frequency)
        self._run_statistics_computation = tf.equal(steps_into_period, 0)

        # All vars that are preconditioned.
        self._all_vars_for_preconditioning = []
        self._partition_info = PartitionConfig(block_partition_threshold_size,
                                               block_size)
        self._partitioner_metadata = {}
Beispiel #20
0
  def __init__(self, params):
    """Builds the task: input child, bookkeeping fields, global step, learners.

    Args:
      params: Task params. `params.cls` must be a subclass of `BaseTask`.

    Raises:
      ValueError: If the 'TASK_GLOBAL_STEP' collection holds more than one
        entry matching this task's name.
    """
    assert issubclass(params.cls, BaseTask)
    # The global step variable must exist before the parent constructor runs.
    py_utils.GetOrCreateGlobalStepVar()
    super(BaseTask, self).__init__(params)

    p = self.params

    if p.input:
      # TODO(zhifengc): Consider a simpler way to ensure the input
      # generator stops after one epoch.
      if p.is_eval and p.eval:
        reads_from_files = issubclass(
            p.input.cls, base_input_generator.BaseInputGeneratorFromFiles)
        if p.input.num_samples == 0:
          # Dataset size is unknown, so the summary must be driven by an
          # explicit samples_per_summary.
          assert p.eval.samples_per_summary > 0
        else:
          summarize_everything = p.eval.samples_per_summary == 0
          if (summarize_everything or
              p.input.num_samples < p.eval.samples_per_summary):
            # The dataset size is known and the whole set is to be evaluated:
            # have the input generator flush all samples so the evaler and
            # decoder compute metrics on the full set at each summary step.
            if reads_from_files:
              p.input.flush_every_n = p.input.num_samples
            p.eval.samples_per_summary = p.input.num_samples
        if reads_from_files and p.input.num_batcher_threads > 1:
          tf.logging.warning('input.num_batcher_threads > 1 inside eval mode.  '
                             'The input generator may not iterate over exactly '
                             'one epoch per run')
      tf.logging.info('input_params: %s', p.input)
      placed_input_params = self.cluster.PlaceInput(p.input)
      with py_utils.outside_all_rewrites():
        self.CreateChild('input', placed_input_params)

    # Sub-components start out unset.
    self._encoder = None
    self._online_encoder = None
    self._decoder = None

    # Training outputs and metric containers start out empty.
    self._loss = None
    self._num_predictions = None
    self._train_op = None
    self._eval_metrics = {}
    self._per_example = {}
    self._trainer_verbose_tensors = {}

    # Create the gradient mask,
    self._per_input_gradient_mask = None

    # Prefer a task-specific global step if exactly one is registered;
    # otherwise fall back to the shared global step variable.
    matching_steps = tf.get_collection('TASK_GLOBAL_STEP',
                                       '^%s_global_step' % p.name)
    if len(matching_steps) > 1:
      raise ValueError('Found multiple task_global_step for task %s' % p.name)
    if len(matching_steps) == 1:
      self._global_step_var = matching_steps[0]
    else:
      self._global_step_var = py_utils.GetOrCreateGlobalStepVar()
    self._global_step = tf.identity(
        self._global_step_var, name='global_step_tensor')

    tp = p.train
    # p.train can be None if this task is the teacher/student task in a
    # DistillationTask.
    training_jobs = ('worker', 'trainer', 'trainer_client', 'controller',
                     'executor_tpu')
    if tp and self.cluster.job in training_jobs:
      self._SetLearnerFromLegacyParams(tp)
      if tp.learner is not None:
        # CreateChildren takes a sequence, so a single learner is wrapped.
        if isinstance(tp.learner, (list, tuple)):
          learner_params = tp.learner
        else:
          learner_params = [tp.learner]
        self.CreateChildren('learners', learner_params)
    self._UpdateVnConfig()
Beispiel #21
0
def AddAttentionSummaryBatchMajor(name,
                                  attention_tensors,
                                  src_paddings,
                                  tgt_paddings,
                                  transcripts=None,
                                  max_outputs=3):
    """Adds an image summary showing the attention probability matrix and state.

    As opposed to AddAttentionSummary() takes all tensors with batch dimension
    in axis 0.

    Note that the entries of `attention_tensors` are replaced in place with
    shape-checked, renamed copies, and that this happens even when no summary
    ends up being added.

    Args:
      name: Summary name.
      attention_tensors: A list of 3D tensors shaped [batch_size, target_len,
        source_len] where attention[b, i, j] is the probability for the i-th
        output attending to the j-th input for element b in the batch.
      src_paddings: A tensor of binary paddings shaped [batch, source_len] for
        the source sequence. Or a list of tensors of the same length as
        attention_tensors with a separate paddings for each entry in
        attention_tensors.
      tgt_paddings: A tensor of binary paddings shaped [batch, target_len] for
        the target sequence. Or a list of tensors of the same length as
        attention_tensors with a separate paddings for each entry in
        attention_tensors.
      transcripts: Optional, transcripts shaped [batch, source_len] for the
        source sequence.
      max_outputs: Integer maximum number of elements of the batch to plot.

    Raises:
      ValueError: If a paddings list has a length other than 1 or
        len(attention_tensors).
    """
    def VerifyLen(paddings):
        # A non-list argument counts as a single entry shared by all tensors.
        length = len(paddings) if isinstance(paddings, list) else 1
        if length != 1 and length != len(attention_tensors):
            raise ValueError('Bad length of paddings list {}'.format(length))

    VerifyLen(src_paddings)
    VerifyLen(tgt_paddings)

    # Verify shapes.
    for i, attention_tensor in enumerate(attention_tensors):
        src, tgt = src_paddings, tgt_paddings
        # Pick the paddings for this entry: a one-element list is shared by
        # every attention tensor, a longer list is indexed per entry.
        src = src[0 if len(src) == 1 else i] if isinstance(src, list) else src
        tgt = tgt[0 if len(tgt) == 1 else i] if isinstance(tgt, list) else tgt
        tgt_shape = py_utils.GetShape(tgt)
        # Expected attention shape is the target paddings shape with the
        # source paddings' axis-1 size inserted at position 2. The checked
        # tensor is written back into the caller's list under a name derived
        # from GetTensorName with everything from the first ':' removed.
        attention_tensors[i] = tf.identity(
            py_utils.with_dependencies([
                py_utils.assert_equal(
                    py_utils.GetShape(attention_tensor), tgt_shape[:2] +
                    [py_utils.GetShape(src)[1]] + tgt_shape[2:])
            ], attention_tensor),
            re.sub(':.*$', '', GetTensorName(attention_tensor, name, i)))

    # The shape checks above are always built; plotting is optional.
    if not _ShouldAddSummary():
        return

    def ToLengths(paddings):
        # Normalizes to a list and converts each paddings tensor to lengths.
        paddings = paddings if isinstance(paddings, list) else [paddings]
        return [SequenceLength(p) for p in paddings]

    def Get(lengths, i):
        # Same sharing rule as above: a single entry serves every index.
        return lengths[0 if len(lengths) == 1 else i]

    src_lens = ToLengths(src_paddings)
    tgt_lens = ToLengths(tgt_paddings)

    with plot.MatplotlibFigureSummary(name + '/Attention',
                                      max_outputs=max_outputs,
                                      gridspec_kwargs={'hspace': 0.3}) as fig:
        for n, atten in enumerate(attention_tensors):
            # Diagnostic metric that decreases as attention picks up.
            # max_entropy is log(source length), given two trailing axes so it
            # broadcasts against the 3D attention tensor.
            max_entropy = tf.math.log(tf.cast(Get(src_lens, n), tf.float32))
            max_entropy = tf.expand_dims(tf.expand_dims(max_entropy, -1), -1)
            # The 1e-10 keeps the log finite where attention is exactly 0.
            atten_normalized_entropy = -atten * tf.math.log(
                atten + 1e-10) / max_entropy
            scalar(name + '/Attention/average_normalized_entropy/%d' % n,
                   tf.reduce_mean(atten_normalized_entropy))
            args = [atten, Get(src_lens, n), Get(tgt_lens, n)]
            # Transcripts are only passed along for the first subplot.
            if transcripts is not None and n == 0:
                args.append(transcripts)
            fig.AddSubplot(args,
                           TrimPaddingAndPlotAttention,
                           title=GetTensorName(atten, name, n),
                           xlabel='Input',
                           ylabel='Output')