Example 1
    def _finish(self, update_ops, name_scope):
        with tf.control_dependencies(update_ops):
            ops1 = self.magnitude_optimizer._finish([], name_scope + "_m")  # pylint: disable=protected-access
            ops2 = self.direction_optimizer._finish([], name_scope + "_d")  # pylint: disable=protected-access

            if self.use_global_norm:  # apply global grafting
                with tf.control_dependencies([ops1, ops2]):
                    m_global_norm = tf.Variable(0.)
                    d_global_norm = tf.Variable(0.)
                    m_norm_updates = []
                    d_norm_updates = []
                    for var in self._variables:
                        m_step_norm = self.get_slot(var, "m_step_norm")
                        d_step_norm = self.get_slot(var, "d_step_norm")
                        m_norm_updates.append(
                            tf.assign_add(m_global_norm, m_step_norm**2))
                        d_norm_updates.append(
                            tf.assign_add(d_global_norm, d_step_norm**2))

                    # The accumulation ops above have no consumers, so force
                    # them to run before the multiplier is computed.
                    with tf.control_dependencies(m_norm_updates + d_norm_updates):
                        multiplier = tf.sqrt(m_global_norm /
                                             tf.maximum(d_global_norm, 1e-30))

                    step_ops = []
                    for var in self._variables:
                        # Use this variable's own direction-step norm.
                        d_step_norm = self.get_slot(var, "d_step_norm")
                        d_step = self.get_slot(var, "scratch_copy")
                        step = tf.where(tf.greater(d_step_norm, 0),
                                        multiplier * d_step,
                                        tf.zeros_like(d_step))
                        step_op = tf.assign_add(
                            var, self._learning_rate_tensor * step)
                        step_ops.append(step_op)
                    return tf.group(*step_ops, name=name_scope)

        return tf.group(*([ops1, ops2] + update_ops), name=name_scope)
Example 2
  def _BPropForVariables(self, vmap):
    """Constructs the backward graph."""
    train_ops = self._BPropGenTrainOps(vmap)

    # TODO(rpang): try to structure _train_op as:
    #   tf.cond(skip_step, <only update skip stats>, <all updates>)
    # so that we skip all other updates when a step is skipped.
    with tf.control_dependencies(
        [tf.group(*tf.nest.flatten(train_ops), name='train_ops')]):
      self._train_op = tf.group(self._post_train_ops, name='bprop')
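
The pattern above (group the per-step update ops, then make the final training op depend on the grouped op) can be reproduced in isolation. Below is a minimal, self-contained sketch, assuming TF1-style graph mode via tf.compat.v1; all variable and op names are illustrative only:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

a = tf.get_variable('a', shape=[], initializer=tf.zeros_initializer())
b = tf.get_variable('b', shape=[], initializer=tf.zeros_initializer())

# Per-step update ops that must all complete before the final op runs.
update_ops = [tf.assign_add(a, 1.0), tf.assign_add(b, 2.0)]

with tf.control_dependencies([tf.group(*update_ops, name='train_ops')]):
  train_op = tf.no_op(name='bprop')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  print(sess.run([a, b]))  # [1.0, 2.0]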
Example 3
    def __init__(self,
                 inference_graph,
                 subgraph_name=None,
                 checkpoint=None,
                 device_type="gpu",
                 tf_master="",
                 session_config=None,
                 clear_device_placement=False):
        assert device_type in ["cpu", "gpu", "tpu"]
        subgraph_name = subgraph_name or "default"
        if isinstance(inference_graph, six.string_types):
            tf.logging.info("Reading inference graph from %s.",
                            inference_graph)
            inference_graph = LoadInferenceGraph(inference_graph,
                                                 clear_device_placement)
        self._inference_graph = inference_graph
        self._checkpoint = checkpoint
        self._device_type = device_type
        self._tf_master = tf_master
        self._session_config = session_config

        self._graph = tf.Graph()
        with self._graph.as_default():
            tf.logging.info(
                "Loading inference graph for prediction subgraph_name={}.".
                format(subgraph_name))
            self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
            with tf.device("/%s:0" %
                           "cpu" if device_type == "tpu" else device_type):
                tf.import_graph_def(inference_graph.graph_def, name="")
            if device_type == "tpu":
                # If no tpu init op exists, create it here.
                try:
                    self._graph.get_operation_by_name("tpu_init_op")
                except KeyError:
                    tf.group(tf.tpu.initialize_system(), name="tpu_init_op")

            self._graph.finalize()

        if inference_graph.subgraphs:
            if subgraph_name not in inference_graph.subgraphs:
                raise ValueError(
                    "Subgraph %s not defined. Valid subgraphs: %s" %
                    (subgraph_name, list(inference_graph.subgraphs.keys())))
            subgraph = inference_graph.subgraphs[subgraph_name]
            self._fetches = subgraph.fetches
            self._feeds = subgraph.feeds
        else:
            self._fetches = inference_graph.fetches
            self._feeds = inference_graph.feeds

        # Lock for creating new sessions.
        self._sess_lock = threading.Lock()
        self._cur_sess_id = 0
        self._CreateNewSession()
Example 4
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        if self._num_micro_batches == 1:
            return self._opt.apply_gradients(grads_and_vars, global_step)
        global_step = global_step or py_utils.GetOrCreateGlobalStepVar()
        with tf.init_scope():
            self._create_slots([v for (_, v) in grads_and_vars])

        accums = []
        variables = []

        for g, v in grads_and_vars:
            accum = self.get_slot(v, 'grad_accum')
            variables.append(v)
            # pytype: disable=attribute-error
            if isinstance(g, tf.IndexedSlices):
                scaled_grad = tf.IndexedSlices(g.values /
                                               self._num_micro_batches,
                                               g.indices,
                                               dense_shape=g.dense_shape)
            else:
                scaled_grad = g / self._num_micro_batches
            accum_tensor = accum.read_value()
            accums.append(accum.assign(accum_tensor + scaled_grad))
            # pytype: enable=attribute-error

        def _ApplyAndReset():
            normalized_accums = accums
            if self._apply_crs_to_grad:
                normalized_accums = [
                    tf.tpu.cross_replica_sum(accum.read_value())
                    for accum in accums
                ]
            apply_op = self._opt.apply_gradients(
                list(zip(normalized_accums, variables)))
            with tf.control_dependencies([apply_op]):
                zero_op = [
                    tf.assign(accum, tf.zeros_like(accum)) for accum in accums
                ]
            return tf.group(zero_op, tf.assign_add(global_step, 1))

        def _Accum():
            return tf.no_op()

        accum_step = tf.cond(
            tf.equal(
                tf.math.floormod(self._counter + 1, self._num_micro_batches),
                0),
            _ApplyAndReset,  # Apply the accumulated gradients and reset.
            _Accum)  # Accumulate gradients.

        with tf.control_dependencies([tf.group(accums)]):
            return tf.group(accum_step, tf.assign_add(self._counter, 1))
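
As a standalone illustration of the accumulate-then-apply idea used here, the sketch below accumulates a scaled gradient every step and only touches the weight every N steps. It is a minimal sketch assuming TF1-style graph mode via tf.compat.v1, with a plain SGD step standing in for the wrapped optimizer; all names and constants are illustrative only:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

N = 4  # number of micro-batches to accumulate before applying
var = tf.get_variable('w', shape=[], initializer=tf.ones_initializer())
accum = tf.get_variable('w_accum', shape=[], trainable=False,
                        initializer=tf.zeros_initializer())
counter = tf.get_variable('counter', shape=[], dtype=tf.int64, trainable=False,
                          initializer=tf.zeros_initializer())

grad = tf.constant(0.5)            # stand-in for a per-micro-batch gradient
new_accum = accum.assign_add(grad / N)
increment = counter.assign_add(1)

def _ApplyAndReset():
  # Apply the accumulated gradient (plain SGD here), then zero the accumulator.
  apply_op = var.assign_sub(0.1 * new_accum)
  with tf.control_dependencies([apply_op]):
    return tf.group(accum.assign(tf.zeros_like(accum)))

def _Accum():
  return tf.no_op()

maybe_apply = tf.cond(
    tf.equal(tf.math.floormod(increment, N), 0), _ApplyAndReset, _Accum)
train_op = tf.group(maybe_apply)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(N):
    sess.run(train_op)
  print(sess.run(var))  # 1.0 - 0.1 * 0.5 = 0.95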
Example 5
def partitioned_variable_assign(partitioned_var, new_value):
    """Assign op for partitioned variables.

  Args:
    partitioned_var: A partitioned tensorflow variable
    new_value: Value to be assigned to the variable var

  Returns:
    A tensorflow op that groups the assign ops for each of the variable slices
  """
    # Determine which axis was used to partition the variable. Currently
    # TensorFlow allows partitioning a variable along only one axis.
    axis = 0 if len(partitioned_var) == 1 else determine_partitioned_axis(
        partitioned_var)

    partition_sizes = np.array(
        [partition.get_shape()[axis] for partition in partitioned_var])
    new_partitioned_values = tf.split(new_value,
                                      tf.convert_to_tensor(partition_sizes,
                                                           dtype=tf.int32),
                                      axis=axis)
    op_list = []
    for partition in partitioned_var:
        op_list.append(
            variable_assign(partition, new_partitioned_values[len(op_list)]))
    return tf.group(*op_list, name=partitioned_var.name + '_group_assign')
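
For context, the same group-of-slice-assigns idea can be shown end to end with TensorFlow's built-in partitioner instead of the helpers above (determine_partitioned_axis and variable_assign are not reproduced here). A minimal sketch assuming TF1-style graph mode via tf.compat.v1 and partitioning along axis 0:

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# A [4, 3] variable split into 2 slices along axis 0.
pvar = tf.get_variable(
    'w', shape=[4, 3], initializer=tf.zeros_initializer(),
    partitioner=tf.fixed_size_partitioner(2, axis=0))

new_value = tf.constant(np.arange(12, dtype=np.float32).reshape(4, 3))

# Split the new value to match the slice sizes, then group the per-slice assigns.
slice_sizes = [int(p.shape[0]) for p in pvar]
new_slices = tf.split(new_value, slice_sizes, axis=0)
assign_op = tf.group(
    *[p.assign(s) for p, s in zip(pvar, new_slices)], name='w_group_assign')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(assign_op)
  print(sess.run(tf.convert_to_tensor(pvar)))  # the full [4, 3] value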
Example 6
  def Apply(self, lr, var_grad):
    p = self.params

    def _Acc(vg):
      """Updating accumulators."""

      v, g = vg
      with tf.variable_scope(v.op.name):
        a = py_utils.CreateVariable(
            'grad_accumulator',
            py_utils.WeightParams(v.get_shape(),
                                  py_utils.WeightInit.Constant(0.0),
                                  self.params.dtype),
            trainable=False)
        a = tf.assign_add(a, g)

      return py_utils.VarGrad(v, a)

    var_grad = var_grad.Transform(_Acc)

    def _ApplyAndReset():
      with tf.control_dependencies([
          self._opt.Apply(
              lr, py_utils.ApplyGradMultiplier(var_grad, 1. / p.accum_steps))
      ]):
        return tf.group(
            *[tf.assign(a, tf.zeros_like(a)) for _, a in var_grad.Flatten()])

    if self.params.add_summary_in_apply:
      self.AddSummary(lr, self.GetOptimizer(lr), var_grad)
    return tf.cond(
        tf.equal(
            tf.math.floormod(self.global_step, p.accum_steps),
            p.accum_steps - 1), _ApplyAndReset, lambda: tf.group(tf.no_op()))
Example 7
    def _TestHelper(self, params, test_data, enroll_data):
        """Returns the attentive scores for the given test and enrollment data.

    Args:
      params: Babelfish configuration parameters for setting up the
        attentive_scoring_layer.
      test_data: Test data related to 2 test utterances each with 2 key vectors
        of 2 dimensions and 2 value vectors of 3 dimensions. Each utterance
        representation contains a packed form of the key and value vectors. The
        result is a 2d tensor (of tf.float32 elements) of dimension
        [num_test_utts, representation_dim]. In this example, num_test_utts is 2
        and representation_dim is 10 (or 2 keys * (2 key_dim + 3 value_dim)).
      enroll_data: Enrollment data related to 2 speakers each with 2 enrollment
        utterances. Each utterance is composed of 2 key vectors of 2 dimensions
        and 2 value vectors of 3 dimensions. Each utterance is a packed form of
        the key and value vectors. The result is a 3d tensor (of tf.float32
        elements) of dimension [num_enroll_spks, num_enroll_utts_per_spk,
        representation_dim].

    Returns:
      The output of the attentive scoring. The result is a numpy np.float32
      tensor of shape [num_test_utts, num_enroll_spks].
    """

        with self.session() as sess:
            tf.random.set_seed(_TF_RANDOM_SEED)
            attention_network = params.Instantiate()

            output = attention_network.FProp((test_data, enroll_data))

            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.tables_initializer()))

            return sess.run(output)
Example 8
 def _ApplyAndReset():
   with tf.control_dependencies([
       self._opt.Apply(
           lr, py_utils.ApplyGradMultiplier(var_grad, 1. / p.accum_steps))
   ]):
     return tf.group(
         *[tf.assign(a, tf.zeros_like(a)) for _, a in var_grad.Flatten()])
Example 9
 def _Apply():
     """Use the matched optimizer to apply the gradients."""
     train_ops = []
     non_default_regex = [
         regex for regex in self._optimizer_map
         if regex != 'default_optimizer'
     ]
     for regex in self._optimizer_map:
         if var_grad_map[regex]:
             opt = tf_optimizer_map[regex]
             train_ops.append(opt.apply_gradients(var_grad_map[regex]))
             # pylint: disable=cell-var-from-loop, g-long-lambda
             if regex == 'default_optimizer':
                 filtered_var_grad = var_grad.FilterKeyVal(
                     lambda k, v: any([
                         re.match(i, v.var.name)
                         for i in non_default_regex
                     ]))
             else:
                 filtered_var_grad = var_grad.FilterKeyVal(
                     lambda k, v: (re.match(regex, v.var.name)))
             # pylint: enable=cell-var-from-loop, g-long-lambda
             self._optimizer_map[regex].AddSummary(
                 self._lr_map[regex], opt, filtered_var_grad)
     return tf.group(*train_ops, name='composite_optimizer_train_op')
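
To make the composite pattern concrete in isolation: two stock optimizers each apply gradients to their own variable subset, and the resulting update ops are grouped into one train op. A minimal sketch assuming TF1-style graph mode via tf.compat.v1; the variable split and optimizer choices are illustrative only:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

w_enc = tf.get_variable('encoder_w', shape=[], initializer=tf.ones_initializer())
w_dec = tf.get_variable('decoder_w', shape=[], initializer=tf.ones_initializer())
loss = tf.square(w_enc) + tf.square(w_dec)

# One optimizer per variable group.
opt_enc = tf.train.GradientDescentOptimizer(0.1)
opt_dec = tf.train.AdamOptimizer(0.01)

g_enc, g_dec = tf.gradients(loss, [w_enc, w_dec])
train_ops = [
    opt_enc.apply_gradients([(g_enc, w_enc)]),
    opt_dec.apply_gradients([(g_dec, w_dec)]),
]
train_op = tf.group(*train_ops, name='composite_optimizer_train_op')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  print(sess.run([w_enc, w_dec]))  # each weight moved by its own optimizer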
Example 10
    def testAttentionNetworkFPropTrainableScaleFactor(self):
        """Checks that the forward propagation is correct for trainable scaling."""

        # Enable trainable scaling
        self.params.use_trainable_scale_factor = True
        self.params.scale_factor = 2.0

        with self.session() as sess:
            tf.random.set_seed(_TF_RANDOM_SEED)
            attention_network = self.params.Instantiate()

            output = attention_network.FProp(
                (self.test_data, self.enroll_data), attention_network.theta)
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.tables_initializer()))
            output_result = sess.run(output)
            log_scale_factor_result = sess.run(
                attention_network.theta.trainable_log_scale_factor)

        expected_output = np.array([[1.0, 0.434889], [-0.269374, -0.202443]])
        self.assertAllClose(expected_output,
                            output_result,
                            rtol=_REL_TOLERANCE,
                            atol=_ABS_TOLERANCE)

        # Check that the log of the scale_factor=2 is as expected
        self.assertAllClose(0.6931471824645996,
                            log_scale_factor_result,
                            rtol=_REL_TOLERANCE,
                            atol=_ABS_TOLERANCE)
Example 11
 def PostTrainingStepUpdate(self, global_step):
     ops = [
         super(PassiveAsymQDomain, self).PostTrainingStepUpdate(global_step)
     ]
     for t_name in self._t_names:
         ops.extend(self._RecordTensor(t_name))
         self._SummarizeTensor(t_name)
     return tf.group(ops)
Example 12
    def CreateTpuEmbeddingEnqueueOps(self):
        """Creates the TpuEmbedding enqueue ops on the host.

    Note that this must be called after the instantiation of the
    monolithic TPUEmbeddingLayer.
    """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        enqueue_ops = []

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')
        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        if not tpu_embedding:
            return

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                enqueue_dict_per_core = [
                    {} for _ in range(tpu_embedding.num_cores_per_host)
                ]
                num_cores_per_host = tpu_embedding.num_cores_per_host
                for key in tpu_emb_input_keys:
                    feat = self._batch[key]
                    tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
                    for core, split in enumerate(tpu_emb_feat_splitted):
                        # Dense to sparse. Note the assumption of a padding id.
                        sample_indices = tf.where(tf.not_equal(split, -1))
                        embedding_indices = tf.gather_nd(split, sample_indices)
                        enqueue_data = tpu_embedding_lib.EnqueueData(
                            embedding_indices, sample_indices)
                        enqueue_dict_per_core[core][key] = enqueue_data
                enqueue_ops += tpu_embedding.generate_enqueue_ops(
                    enqueue_dict_per_core)
        self._tpu_infeed_op.append(tf.group(*enqueue_ops))
Example 13
  def PostTrainingLoop(self, outfeed=None):
    """Construct the post training loop op.

    Args:
      outfeed: a dict of tensors dequeued from TPU outfeed queue.
    """
    with py_utils.GlobalStepContext(self._global_step_var):
      self._post_training_loop_op = tf.group(
          *[opt.ApplyPostTrainingLoop() for opt in self.learners])
Example 14
    def ApplyPostTrainingLoop(self):
        """Applies any computation to run after each tpu trainining loop.

    Returns:
      Ops to run after training loop ends.
    """
        invoke_async_ops = self._optimizer.invoke_async_preconditioner_computation(
            tf.cast(py_utils.GetGlobalStep(), tf.int32))
        assign_ops = self._optimizer.assign_preconditioner_to_host_vars()
        return tf.group(*[invoke_async_ops, assign_ops])
Example 15
  def ApplyPostTrainingLoop(self):
    """Apply computation to run after each tpu training loop for each optimizer.

    Returns:
      Ops to run after training loop ends.
    """
    post_training_ops = [
        opt.ApplyPostTrainingLoop() for _, opt in self._optimizer_map.items()
    ]
    return tf.group(*post_training_ops)
Example 16
    def Apply(self, metrics, vmap, gradient_mask=None, gradient_adjuster=None):
        """Computes updates on 'vmap' to optimize 'loss'.

    TODO(rpang): explore merging gradient_mask and gradient_adjuster.

    Args:
      metrics: A Dict[str, (value, weight)], from which loss can be extracted
        according to p.loss_name.
      vmap: A `.NestedMap` object containing variables to optimize.
      gradient_mask: if not None, a dict mapping variable names to a 0/1 scalar.
      gradient_adjuster: if not None, a function that mutates a given var_grads.

    Returns:
      (losses, op, eval_metrics), where
        - losses is a list of scalar tensors;
        - op is a tf.Operation to update variables;
        - eval_metrics is a Dict[str, (value, weight)], where each value/weight
          is a scalar tensor.
    """
        # We apply gradients outside the name_scope to maintain backwards
        # compatibility on variables created by self.optimizer.Apply().
        losses, var_grads, eval_metrics = self._ComputeLossesAndGradients(
            metrics, vmap)
        if 'tpu_embedding_var_grads' in var_grads:
            tpu_embedding_var_grads = var_grads.tpu_embedding_var_grads
            del var_grads.tpu_embedding_var_grads

            tpu_embedding_collection = py_utils.GetTpuEmbeddingGraphCollection(
            )[0]
            assert tpu_embedding_collection
            tpu_emb_update_op, stats = tpu_embedding_collection.ApplyGradients(
                py_utils.GetTaskCallScope(),
                tpu_embedding_var_grads.Transform(
                    lambda var_grad: var_grad.grad))
            eval_metrics.update(stats)
        else:
            tpu_emb_update_op = tf.no_op()

        assert py_utils.GetGlobalStep() is not None
        lr = self.LearningRate()

        var_grads, stats = self.AdjustGradients(
            var_grads,
            gradient_mask=gradient_mask,
            gradient_adjuster=gradient_adjuster)
        eval_metrics.update(stats)
        self._var_grads = var_grads

        eval_metrics['learning_rate'] = (tf.convert_to_tensor(lr),
                                         tf.convert_to_tensor(1.))

        var_update_op = tf.group(
            [tpu_emb_update_op,
             self.optimizer.Apply(lr, var_grads)])
        return losses, var_update_op, eval_metrics
Example 17
 def _ApplyAndReset():
   normalized_accums = accums
   if self._apply_crs_to_grad:
     normalized_accums = [
         tf.tpu.cross_replica_sum(accum.read_value()) for accum in accums
     ]
   apply_op = self._opt.apply_gradients(
       list(zip(normalized_accums, variables)))
   with tf.control_dependencies([apply_op]):
     zero_op = [tf.assign(accum, tf.zeros_like(accum)) for accum in accums]
   return tf.group(zero_op, tf.assign_add(global_step, 1))
Example 18
  def PostTrainingStepUpdate(self):
    """Returns a TF op which will be invoked at each training step.

    Subclasses of `BaseLayer` can implement this method. The method should
    return a TF op to be invoked during training after gradients are applied.
    """
    update_ops = [
        child.PostTrainingStepUpdate()
        for child in self._private_children.Flatten()
    ]
    return tf.group(*update_ops)
Example 19
 def _BuildRestore(self):
   """Builds restore ops."""
   assign_ops = []
   for var in self._vars:
     val, = io_ops.restore_v2(
         prefix=self._restore_prefix_ph,
         tensor_names=[_VarKey(var)],
         shape_and_slices=[""],
         dtypes=[var.dtype])
     assign_ops.append(var.assign(val))
   self._restore_op = tf.group(*assign_ops)
Example 20
    def ApplyPostTrainingLoop(self, global_step):
        """Applies any computation to run after each tpu trainining loop.

    Args:
      global_step: Global step variable.

    Returns:
      Ops to run after training loop ends.
    """
        invoke_async_ops = self._optimizer.invoke_async_preconditioner_computation(
            tf.cast(global_step, tf.int32))
        assign_ops = self._optimizer.assign_preconditioner_to_host_vars()
        return tf.group(*[invoke_async_ops, assign_ops])
Example 21
    def ApplyPostTrainingLoop(self, global_step):
        """Apply any computation to run after each tpu training loop for each optimizer.

    Args:
      global_step: Global step variable.

    Returns:
      Ops to run after training loop ends.
    """
        post_training_ops = [
            opt.ApplyPostTrainingLoop(global_step)
            for _, opt in self._optimizer_map.items()
        ]
        return tf.group(*post_training_ops)
Example 22
    def assign_preconditioner_to_host_vars(self):
        """Assign/Grab latest copy of preconditioners."""
        keys_shapes_and_preconditioner_vars = []
        assign_ops = []
        for var in self._all_vars_for_preconditioning:
            shape = var.get_shape()
            if not self._fallback_to_diagonal_for_shape(shape):
                partitioned_v = TensorPartitioner.partition_tensor(
                    var, self._partition_info)
                num_partitions = len(partitioned_v)
                for pt_idx, pt in enumerate(partitioned_v):
                    pt_shape = pt.get_shape()
                    preconditioner_exists_for_dim = (
                        self._preconditioner_available_for_dims(pt_shape))
                    var_rank = len(pt_shape)
                    for i in range(var_rank):
                        if preconditioner_exists_for_dim[i]:
                            key = self._key_for_var(var, i, pt_idx)
                            preconditioner = self.get_slot(
                                var,
                                self._preconditioner_key_for_partition_and_dim(
                                    i, pt_idx, num_partitions))
                            keys_shapes_and_preconditioner_vars.append(
                                (key, tf.shape(preconditioner),
                                 preconditioner))

        if not keys_shapes_and_preconditioner_vars:
            return tf.no_op()

        keys, shapes, preconditioner_vars = zip(
            *keys_shapes_and_preconditioner_vars)

        preconditioner_vals, successes = x_ops.get_preconditioners(
            shapes,
            keys=keys,
            preconditioner_compute_graphdef=(
                self._preconditioner_compute_graphdef))

        for preconditioner_var, preconditioner_val, success in zip(
                preconditioner_vars, preconditioner_vals, successes):
            success_mult = tf.cast(success, preconditioner_var.dtype)
            assign_ops.append(
                state_ops.assign(
                    preconditioner_var,
                    (1.0 - success_mult) * preconditioner_var +
                    success_mult * preconditioner_val))
        return tf.group(*assign_ops)
Example 23
    def _BuildRestore(self):
        """Builds restore ops."""
        assign_ops = []
        with self._var_graph.as_default():
            per_device = collections.defaultdict(lambda: [])
            for var in self._vars:
                per_device[var.device].append(var)

            for device, var_list in per_device.items():
                with self._var_graph.device(device):
                    for var in var_list:
                        val, = io_ops.restore_v2(
                            prefix=self._restore_prefix_ph,
                            tensor_names=[_VarKey(var)],
                            shape_and_slices=[""],
                            dtypes=[var.dtype])
                        assign_ops.append(var.assign(val))

        self._restore_op = tf.group(*assign_ops)
Example 24
 def PostTrainingStepUpdate(self, global_step):
     """Updates moving_mean, moving_variance after each training step."""
     p = self.params
     # Get sufficient stats that accumulates over microbatches.
     counts = self.accumulators.counts.GetValue()
     mean_ss = self.accumulators.mean_ss.GetValue()
     variance_ss = self.accumulators.variance_ss.GetValue()
     # Compute batch mean and batch variance from sufficient stats
     mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss,
                                              None)
     decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
      # Update moving_mean, moving_variance from batch mean and batch variance.
     with tf.name_scope(p.name) as scope:
         with tf.colocate_with(self.vars.moving_mean):
             mean_update = tf.assign_sub(
                 self.vars.moving_mean,
                 tf.where(tf.greater(counts, 0.5),
                          (self.vars.moving_mean - tf.cast(mean, p.dtype)) *
                          decay, tf.zeros_like(self.vars.moving_mean)),
                 name='moving_mean_update')
         with tf.colocate_with(self.vars.moving_variance):
             var_update = tf.assign_sub(
                 self.vars.moving_variance,
                 tf.where(tf.greater(counts, 0.5),
                          (self.vars.moving_variance -
                           tf.cast(variance, p.dtype)) * decay,
                          tf.zeros_like(self.vars.moving_variance)),
                 name='moving_variance_update')
         py_utils.CheckNumerics(
             self.vars.moving_mean,
             'moving mean of {} failed numeric check'.format(scope))
         py_utils.CheckNumerics(
             self.vars.moving_variance,
             'moving variance of {} failed numeric check'.format(scope))
     self.accumulators.counts.Reset()
     self.accumulators.mean_ss.Reset()
     self.accumulators.variance_ss.Reset()
     return tf.group(mean_update, var_update)
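
The guarded moving-average update at the core of this example can be shown on its own: move the statistic toward the batch value only when the step actually saw samples (counts > 0.5), otherwise leave it untouched, and group the resulting update ops. A minimal sketch assuming TF1-style graph mode via tf.compat.v1; the decay value and tensors are illustrative only:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

decay = 0.99
moving_mean = tf.get_variable('moving_mean', shape=[2], trainable=False,
                              initializer=tf.zeros_initializer())
counts = tf.constant(3.0)               # samples seen this step (0 means skip)
batch_mean = tf.constant([1.0, 2.0])

# Broadcast the scalar guard to the variable's shape for tf.compat.v1.where.
has_samples = tf.greater(tf.fill(tf.shape(moving_mean), counts), 0.5)
delta = tf.where(has_samples,
                 (moving_mean - batch_mean) * (1.0 - decay),
                 tf.zeros_like(moving_mean))
mean_update = tf.assign_sub(moving_mean, delta, name='moving_mean_update')
update_op = tf.group(mean_update)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_op)
  print(sess.run(moving_mean))  # [0.01, 0.02]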
Example 25
    def _TestHelperWithState(self, params, list_of_batches):
        """Returns the expected outputs for the tests.

    Args:
      params: Babelfish configuration parameters for setting up the
        cumulative_statistics_layer.
      list_of_batches: A list of padded batches of examples.
        The structure is a list of the following: {
        'features': tf.tensor(float32) of shape(len, batch, dim)
        'paddings': tf.tensor(float32) of shape(len, batch) }

    Returns:
      A dictionary containing numpy arrays of the expected test outputs.
      The structure is as follows:
      {
        'features': np.array(float32) of shape(len, batch, dim)
        'paddings': np.array(float32) of shape(len, batch)
      }
    """

        with self.session() as sess:
            tf.random.set_seed(_TF_RANDOM_SEED)
            network = params.Instantiate()

            batch_size = list_of_batches[0].features.shape[1]
            state = network.zero_state(network.theta, batch_size)

            for batch_t in list_of_batches:
                output = network.FProp(network.theta, batch_t, state)
                # Pass the output state over to the next batch as input state.
                state = output.state

            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.tables_initializer()))

            return sess.run(output)
Example 26
  def _BPropForVariables(self, vmap):
    """Constructs the backward graph."""
    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in six.iteritems(self._per_input_gradient_mask)
      }
    all_losses = []
    for optimization in self.learners:
      loss_name = optimization.params.name
      metric = self._metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(self._metrics.keys())))
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      for key, (value, weight) in six.iteritems(eval_metrics):
        self.AddEvalMetric(key + '/' + loss_name, value, weight)

    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    # Get the op to update the weight masks and thresholds
    train_ops['mask_updates'] = self._GetMaskUpdateOp()

    # Post training step update.
    train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

    with tf.control_dependencies(tf.nest.flatten(train_ops)):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      if self._global_step_var != true_global_step:
        with tf.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    for op_name, op in six.iteritems(train_ops):
      assert op is not None, op_name

    # TODO(rpang): try to structure _train_op as:
    #   tf.cond(skip_step, <only update skip stats>, <all updates>)
    # so that we skip all other updates when a step is skipped.
    self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
Example 27
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Example 28
    def CreateTpuEnqueueOps(self):
        """Create the host-side enqueue ops.

    This should be called in an outer non-TPU context.
    """
        assert not self._tpu_queues, (
            'CreateTpuEnqueueOps should only be called '
            'once.')
        self._tpu_queues = []
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTpuEnqueueOps num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        shards = (cluster.total_worker_devices //
                  num_infeed_hosts) // cluster.num_devices_per_split
        tf.logging.info('shards {}'.format(shards))

        input_ops_list = []
        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')

        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                self._batch = self.GetPreprocessedInputBatch()
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                for k, x in self._batch.FlattenItems():
                    assert x.shape.is_fully_defined(), (
                        'Shape must be fully defined: %s: %s' % (k, x))
                    # TODO(cwhipkey): if it's a string (or other type not supported on
                    # TPU), drop it from feeding and on the other end add in an op that
                    # fails if used.
                shapes = self._batch.Transform(lambda x: x.shape).Flatten()
                dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

                tf.logging.info('host_device: %s infeed shapes: %r',
                                host_device, shapes)
                tf.logging.info('host_device: %s infeed dtypes: %r',
                                host_device, dtypes)

                if p.use_partitioned_infeed_queue:
                    device_assignment = py_utils.GetTpuDeviceAssignment()

                    host_device = device_assignment.host_device(
                        replica=0, job=tf.flags.FLAGS.tf_master)
                    host_id = int(
                        host_device.split('/task:')[1].split('/device:')[0])
                    tf.logging.info('host_id: {} host_device: {}'.format(
                        host_id, host_device))
                    q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                        number_of_tuple_elements=len(dtypes),
                        device_assignment=device_assignment,
                        host_id=host_id,
                        input_partition_dims=[[p.num_partitions] + [1] *
                                              (len(s) - 1) for s in shapes],
                        tuple_types=dtypes,
                        tuple_shapes=shapes)
                else:
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                self._tpu_queues.append(q)

                if p.use_partitioned_infeed_queue:
                    input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
                elif p.use_per_host_infeed:
                    # TODO(ylc/zhifengc): Add this to a policy module and test it.
                    def TPUOrdinalFunction(shard_index_in_host):
                        device_assignment = py_utils.GetTpuDeviceAssignment()
                        if device_assignment:
                            # We put both enqueue/dequeue ops at core 0 in each replica.
                            replica = device_assignment.lookup_replicas(
                                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                            return device_assignment.tpu_ordinal(
                                replica=replica)
                        else:
                            return shard_index_in_host

                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                        tpu_ordinal_function=TPUOrdinalFunction)
                else:
                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        device_assignment=py_utils.GetTpuDeviceAssignment())
                input_ops_list += input_ops

        tf.logging.info('input_ops_list %s', input_ops_list)
        grouped_infeed_op = tf.group(*input_ops_list)
        self._tpu_infeed_op = []
        for _ in range(p.tpu_infeed_parallelism):
            self._tpu_infeed_op.append(grouped_infeed_op)
Example 29
  def Export(cls,
             model_cfg,
             model_task_name=None,
             device_options=InferenceDeviceOptions(
                 device='',
                 retain_device_placement=False,
                 var_options=None,
                 gen_init_op=True,
                 dtype_override=None),
             freeze_checkpoint=None,
             freeze_defaults=False,
             export_path=None,
             subgraph_filter=None,
             random_seed=None,
             disable_packed_input=True):
    """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default initializes the graph and freeze. Useful for
        early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A list of subgraph names. If not None or empty, export
        only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input for inference writing purposes.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
    assert issubclass(model_cfg.cls, base_model.BaseModel)

    # Disable assertions unless user explicitly enables it.
    if FLAGS['enable_asserts'].using_default_value:
      FLAGS.enable_asserts = False

    # TODO(laurenzo): Work out how much we need to specify here in terms of
    # cluster configuration.
    cls._SetClusterParams(model_cfg.cluster, device_options)

    # Configure the model.
    model_cfg.random_seed = random_seed
    model_cfg.is_inference = True

    if disable_packed_input:

      def _DisablePackedInput(task):
        if (_ParamExists(task, 'encoder') and
            _ParamExists(task.encoder, 'packed_input')):
          task.encoder.packed_input = False
        if (_ParamExists(task, 'decoder') and
            _ParamExists(task.decoder, 'packed_input')):
          task.decoder.packed_input = False

      if issubclass(model_cfg.cls, base_model.MultiTaskModel):
        for _, task_param in model_cfg.task_params.IterParams():
          _DisablePackedInput(task_param)
      else:
        _DisablePackedInput(model_cfg.task)

    tf.logging.info('Model %s params:', model_cfg.name)
    for line in model_cfg.ToText().split('\n'):
      tf.logging.info('%s', line)

    # Instantiate the graph.
    graph = tf.Graph()
    with graph.as_default():
      tf.random.set_seed(random_seed)
      cluster = model_cfg.cluster.Instantiate()
      device = cluster.GetPlacer()
      tpu_const_scope = _DummyScope()
      if (IsTpu(device_options) and
          device_options.var_options == 'AS_CONSTANTS'):
        # Do not specify devices for variables if we are marking them as
        # constants.
        device = ''
        tpu_const_scope = ConstGuaranteeScope()

      with cluster, tf.device(device), tpu_const_scope:

        bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
            device_options)

        if bfloat16_override:
          py_utils.UpdateDtype(model_cfg, tf.bfloat16)
          py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

        # Hard-code TPU-related flags prior to instantiating model.
        old_enable_asserts = FLAGS.enable_asserts
        old_xla_device = FLAGS.xla_device
        if IsTpu(device_options):
          FLAGS.enable_asserts = False
          FLAGS.xla_device = 'tpu'

        # Ensure the global_step variable is created.
        _ = py_utils.GetOrCreateGlobalStepVar()
        try:
          mdl = model_cfg.Instantiate()
          task = mdl.GetTask(model_task_name)

          variables_to_restore = (
              _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
              mdl.ema.variables_to_restore(mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables
                .get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
          else:
            saver_var_spec = variables_to_restore

          saver = tf.train.Saver(saver_var_spec)
          tf.variables_initializer(
              tf.global_variables(), name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

    tf.logging.info('Graph contains ops: %r',
                    [op.name for op in graph.get_operations()])

    inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

    # Freezing.
    if freeze_defaults or freeze_checkpoint:
      output_op_names = GetOutputOpNames(
          graph, inference_graph_proto, preserve_colocation_nodes=False)
      if cls._DeviceSupportsFreezing(device_options):
        raise ValueError('freeze_checkpoint cannot be used with device ' +
                         device_options.device)
      if freeze_checkpoint:
        tf.logging.info('Freezing graph from checkpoint: %s',
                        freeze_checkpoint)
        graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                               output_op_names)
      elif freeze_defaults:
        tf.logging.info('Default initializing graph and freezing.')
        graph_def = _FreezeDefaults(graph, output_op_names)
    else:
      output_op_names = GetOutputOpNames(graph, inference_graph_proto)

      # Prune the graph to just the parts we need.
      # To support restoring, we have to not prune out the restore node.
      output_op_names.append('init_all_tables')
      output_op_names.append('init_all_variables')
      output_op_names.append('save/control_dependency')
      output_op_names.append('save/restore_all')
      if IsTpu(device_options) and device_options.gen_init_op:
        output_op_names.append('tpu_init_op')
      graph_def = graph.as_graph_def()
      tf.logging.info('Pruning graph to output ops: %r', output_op_names)
      graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

    if not device_options.retain_device_placement:
      # Clear the device so that the runtime can choose.
      tf.logging.info('Clearing device placement for: %s',
                      device_options.device)
      for node in graph_def.node:
        node.ClearField('device')
      for function in graph_def.library.function:
        for node_def in function.node_def:
          node_def.ClearField('device')

    inference_graph_proto.graph_def.CopyFrom(graph_def)

    if export_path:
      with tf.io.gfile.GFile(export_path, 'w') as f:
        f.write(text_format.MessageToString(inference_graph_proto))
    return inference_graph_proto
Example 30
    def _resource_apply_dense(self, grad, var):
        if grad is None:
            tf.logging.warning('Gradient is None for variable %s' % var.name)
            return []

        grad_dtype = var.dtype  # TODO(lepikhin): add to params
        grad = tf.cast(grad, grad_dtype)
        factored_dims = self._factored_dims(var.shape.as_list())
        if factored_dims:
            vr = self.get_slot(var, 'vr')
            vc = self.get_slot(var, 'vc')
        else:
            v = self.get_slot(var, 'v')
        if self._beta1:
            m = self.get_slot(var, 'm')

        cond = tf.constant(True)

        def _Upd(c, x):
            if not self._cond_is_finite:
                return c
            c = tf.math.logical_and(c, tf.reduce_all(tf.math.is_finite(x)))
            c = tf.math.logical_and(
                c, tf.reduce_all(tf.math.logical_not(tf.math.is_inf(x))))
            return c

        def _Wrap(fn, x, y):
            if not self._cond_is_finite:
                return fn(x, y)
            return tf.cond(cond, lambda: fn(x, y), lambda: x)

        with tf.variable_scope(var.name[:-2] + '/Adafactor'):
            grad_squared = tf.math.square(grad) + tf.cast(
                self._epsilon1, grad_dtype)
            cond = _Upd(cond, grad_squared)
            decay_rate = tf.cast(self._decay_rate, var.dtype)
            old_val = tf.identity(
                var)  # TODO(lepikhin): introduce gradient dtype
            lr = GetLrValue(self._learning_rate)
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(old_val) * tf.cast(
                    lr, grad_dtype)
            else:
                update_scale = lr
            mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
            update_scale = tf.cast(update_scale, grad_dtype)
            updates = []
            if factored_dims:
                d0, d1 = factored_dims
                vr_axis, vc_axis = d0, d1
                grad_squared_row_mean = tf.reduce_mean(grad_squared,
                                                       axis=vr_axis)
                grad_squared_col_mean = tf.reduce_mean(grad_squared,
                                                       axis=vc_axis)
                # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                cond = _Upd(cond, new_vr)
                cond = _Upd(cond, new_vc)
                vr_update = _Wrap(tf.assign, vr, new_vr)
                vc_update = _Wrap(tf.assign, vc, new_vc)
                updates.extend([vr_update, vc_update])
                long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
                r_factor = tf.math.rsqrt(new_vr / long_term_mean)
                c_factor = tf.math.rsqrt(new_vc)
                x = grad * tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
                    c_factor, vc_axis)
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                cond = _Upd(cond, new_v)
                v_update = _Wrap(tf.assign, v, new_v)
                updates.append(v_update)
                x = grad * tf.math.rsqrt(new_v)
            if self._clipping_threshold is not None:
                clipping_denom = tf.maximum(
                    tf.constant(1.0, grad_dtype),
                    py_utils.ReduceRms(x) /
                    tf.constant(self._clipping_threshold, grad_dtype))
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                new_m = (m * tf.constant(self._beta1, dtype=grad_dtype) +
                         subtrahend *
                         tf.constant(1.0 - self._beta1, dtype=grad_dtype))
                subtrahend = new_m
                cond = _Upd(cond, new_m)
                updates.append(_Wrap(tf.assign, m, new_m))
            # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
            #  for the case of bfloat16 activations, so as to avoid repeatedly
            #  rounding the slice value, which results in poor quality.
            cond = _Upd(cond, subtrahend)
            var_update = _Wrap(tf.assign_sub, var, subtrahend)
            updates.append(var_update)
            return tf.group(*updates)