Esempio n. 1
0
    def _finish(self, update_ops, name_scope):
        """Finishes both sub-optimizers and optionally applies global grafting.

        Args:
          update_ops: Ops installed as control dependencies of the two
            sub-optimizers' `_finish` ops.
          name_scope: Name of the returned group op; the sub-optimizers'
            finish ops are named with the suffixes "_m" and "_d".

        Returns:
          A grouped op. When `self.use_global_norm` is set, it adds
          `learning_rate * multiplier * scratch_copy` to every variable, where
          `multiplier = sqrt(sum of squared "m_step_norm" slots /
          max(sum of squared "d_step_norm" slots, 1e-30))`. Otherwise it
          groups the two finish ops with `update_ops`.
        """
        with tf.control_dependencies(update_ops):
            ops1 = self.magnitude_optimizer._finish([], name_scope + "_m")  # pylint: disable=protected-access
            ops2 = self.direction_optimizer._finish([], name_scope + "_d")  # pylint: disable=protected-access

            if self.use_global_norm:  # apply global grafting
                with tf.control_dependencies([ops1, ops2]):

                    def _global_sq_norm(slot_name):
                        # Sum of squared per-variable step norms, built as a
                        # plain tensor expression (rather than tf.assign_add
                        # into fresh variables whose update ops nothing
                        # depends on) so it is computed as part of the
                        # returned op, after ops1/ops2.
                        squares = [
                            tf.square(self.get_slot(v, slot_name))
                            for v in self._variables
                        ]
                        return tf.add_n(squares) if squares else tf.constant(0.)

                    multiplier = tf.sqrt(
                        _global_sq_norm("m_step_norm") /
                        tf.maximum(_global_sq_norm("d_step_norm"), 1e-30))

                    step_ops = []
                    for var in self._variables:
                        d_step = self.get_slot(var, "scratch_copy")
                        # Guard on this variable's own direction-step norm.
                        d_step_norm = self.get_slot(var, "d_step_norm")
                        step = tf.where(tf.greater(d_step_norm, 0),
                                        multiplier * d_step,
                                        tf.zeros_like(d_step))
                        step_op = tf.assign_add(
                            var, self._learning_rate_tensor * step)
                        step_ops.append(step_op)
                    return tf.group(*step_ops, name=name_scope)

        return tf.group(*([ops1, ops2] + update_ops), name=name_scope)
Esempio n. 2
0
    def _Apply2(proj_layer, opt):
      """Applies `opt` to `proj_layer` twice, once per input batch.

      Returns:
        (vars2_intermediate, vars2_1, grads2_1, grads2_2): values read from
        the layer variables after the first Apply, `proj_layer.vars.Flatten()`
        taken after the second Apply, and the (var, grad) tuples computed on
        np_input1 and np_input2 respectively.
      """

      def _GradsFor(np_inputs):
        # Forward + backward pass on one batch (both batches share in_padding1).
        out = proj_layer.FPropDefaultTheta(np_inputs, in_padding1)
        var_grads = py_utils.ComputeGradients(
            tf.reduce_sum(out), proj_layer.vars)
        return var_grads, var_grads.Transform(tuple)

      first_var_grads, first_grads = _GradsFor(np_input1)
      second_var_grads, second_grads = _GradsFor(np_input2)

      with cluster_factory.ForTestingWorker(add_summary=True):
        opt.Apply(lr, first_var_grads)

      # Snapshot the variable values between the two Apply calls.
      snapshot = [v.read_value() for v in proj_layer.vars.Flatten()]
      tf.assign_add(py_utils.GetOrCreateGlobalStepVar(), 1)

      with cluster_factory.ForTestingWorker(add_summary=True):
        opt.Apply(lr, second_var_grads)

      return snapshot, proj_layer.vars.Flatten(), first_grads, second_grads
Esempio n. 3
0
  def testDecoderFPropDeterministicAttentionDropout(self):
    """Verify that attention dropout is deterministic given fixed seeds."""
    with self.session(use_gpu=False) as sess:
      tf.set_random_seed(8372749040)
      noise = py_utils.VariationalNoiseParams(None, True, False, seed=1792)
      p = self._DecoderParams(noise)
      p.use_while_loop_based_unrolling = False
      p.attention.atten_dropout_prob = 0.5
      p.attention.atten_dropout_deterministic = True

      loss, per_sequence_loss = self._testDecoderFPropHelper(params=p)
      global_step = py_utils.GetGlobalStep()
      tf.global_variables_initializer().run()

      def _RunAndCheck(want_loss, want_per_seq, want_step):
        # Evaluate loss, per-sequence loss and global step together.
        loss_val, per_seq_val, step_val = sess.run(
            [loss, per_sequence_loss, global_step])
        print('loss = ', loss_val, 'per sequence loss = ', per_seq_val)
        self.assertAllClose(want_loss, loss_val)
        self.assertAllClose(want_per_seq, per_seq_val)
        self.assertEqual(want_step, step_val)

      _RunAndCheck([3.587372, 15.0],
                   [14.171288, 9.965696, 10.221684, 19.451914], 0)

      # Bump global_step and re-run to check that global_step and time_step
      # are incremented correctly.
      sess.run(tf.assign_add(global_step, 1))
      _RunAndCheck([3.626164, 15.0],
                   [14.70993, 10.572938, 10.516836, 18.592758], 1)
Esempio n. 4
0
 def ComputePredictions(self, theta, input_batch):
     """Runs `self.ffn` on noisy inputs offset by an incremented counter.

     Each call adds 1 to `self.vars.counter1` and adds the updated counter
     value to the inputs.

     Returns:
       A dict with the single key 'result' holding the FFN output.
     """
     noisy = tf.random.normal([1, 10], dtype=tf.float32)
     noisy += tf.cast(input_batch, tf.float32)
     bumped = tf.assign_add(self.vars.counter1, 1.)
     return {'result': self.ffn.FProp(theta.ffn, noisy + bumped)}
Esempio n. 5
0
    def testWeightSpecificSparsity(self):
        """Checks that `weight_sparsity_map` sets per-weight target sparsity."""
        test_spec = ",".join([
            "begin_pruning_step=1",
            "pruning_frequency=1",
            "end_pruning_step=100",
            "target_sparsity=0.5",
            "weight_sparsity_map=[layer1:0.6,layer2/weights:0.75,.*kernel:0.6]",
            "threshold_decay=0.0",
        ])
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        # Three 100-element weights, each wrapped with a pruning mask.
        for scope_name, var_name in (("layer1", "weights"),
                                     ("layer2", "weights"),
                                     ("layer3", "kernel")):
            with tf.variable_scope(scope_name):
                weight = tf.Variable(tf.linspace(1.0, 100.0, 100),
                                     name=var_name)
                pruning.apply_mask(weight)

        p = pruning.Pruning(pruning_hparams)
        mask_update_op = p.conditional_mask_update_op()
        increment_global_step = tf.assign_add(self.global_step, 1)

        with self.cached_session() as session:
            tf.global_variables_initializer().run()
            for _ in range(110):
                session.run(mask_update_op)
                session.run(increment_global_step)

            self.assertAllClose(session.run(pruning.get_weight_sparsity()),
                                [0.6, 0.75, 0.6])
Esempio n. 6
0
 def testConditionalMaskUpdate(self):
     """Checks that masks are only updated every `pruning_frequency` steps.

     The target sparsity is stepped through linspace(0.0, 0.9, 10), one value
     per global step, and the number of non-zero masked weights is recorded
     after each mask update.
     """
     param_list = [
         "pruning_frequency=2", "begin_pruning_step=1",
         "end_pruning_step=6", "nbins=100"
     ]
     test_spec = ",".join(param_list)
     pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
     weights = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
     masked_weights = pruning.apply_mask(weights)
     sparsity = tf.Variable(0.00, name="sparsity")
     # Set up pruning
     p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
     p._spec.threshold_decay = 0.0
     mask_update_op = p.conditional_mask_update_op()
     sparsity_val = tf.linspace(0.0, 0.9, 10)
     increment_global_step = tf.assign_add(self.global_step, 1)
     # Build the per-step sparsity assignments once, up front, rather than
     # adding new ops to the graph on every iteration of the session loop.
     set_sparsity_ops = [
         tf.assign(sparsity, sparsity_val[i]) for i in range(10)
     ]
     non_zero_count = []
     with self.cached_session() as session:
         tf.global_variables_initializer().run()
         for set_sparsity in set_sparsity_ops:
             session.run(set_sparsity)
             session.run(mask_update_op)
             session.run(increment_global_step)
             non_zero_count.append(np.count_nonzero(masked_weights.eval()))
     # Weights pruned at steps 0, 2, 4, and 6
     expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
     self.assertAllEqual(expected_non_zero_count, non_zero_count)
Esempio n. 7
0
 def IncBy(self, delta):
     """Adds `delta` (cast to int64) to the counter and returns the new value.

     A scalar summary of the pre-increment value is emitted under `self._name`.
     """
     int64_delta = tf.cast(delta, tf.int64)
     # NOTE: `self._value` (i.e. `_var + 0`) must be computed before `_var` is
     # updated with the delta, hence the control dependency.
     with tf.control_dependencies([self._value]):
         scalar(self._name, self._value)
         new_value = tf.assign_add(self._var, int64_delta)
         return tf.identity(new_value)
Esempio n. 8
0
 def _ApplyAndReset():
   """Applies the accumulated gradients, then zeroes the accumulators.

   Returns:
     A group op of the accumulator resets plus a `global_step` increment.
   """
   if self._apply_crs_to_grad:
     # Sum every accumulator across TPU replicas before applying it.
     grads_to_apply = [
         tf.tpu.cross_replica_sum(acc.read_value()) for acc in accums
     ]
   else:
     grads_to_apply = accums
   apply_op = self._opt.apply_gradients(list(zip(grads_to_apply, variables)))
   # Reset only once the apply has consumed the accumulated values.
   with tf.control_dependencies([apply_op]):
     reset_ops = [tf.assign(acc, tf.zeros_like(acc)) for acc in accums]
   return tf.group(reset_ops, tf.assign_add(global_step, 1))
Esempio n. 9
0
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Accumulates gradients over micro-batches and applies them periodically.

        With `self._num_micro_batches == 1` this delegates to the wrapped
        optimizer. Otherwise every call adds `grad / num_micro_batches` into a
        per-variable 'grad_accum' slot; on every `num_micro_batches`-th call
        (tracked by `self._counter`) the accumulated gradients are applied,
        the accumulators are zeroed and `global_step` is incremented.

        Args:
          grads_and_vars: List of (gradient, variable) pairs. Gradients may be
            dense tensors or `tf.IndexedSlices`.
          global_step: Optional global step variable. Defaults to
            `py_utils.GetOrCreateGlobalStepVar()`.
          name: Optional name for the returned op.

        Returns:
          An op performing one accumulation step (and, when due, the apply).
        """
        if self._num_micro_batches == 1:
            return self._opt.apply_gradients(grads_and_vars, global_step,
                                             name=name)
        # Compare with None explicitly: `global_step or ...` calls bool() on a
        # tf.Variable, which raises in graph mode and, in eager mode, treats a
        # step whose value is 0 as if no step had been passed.
        if global_step is None:
            global_step = py_utils.GetOrCreateGlobalStepVar()
        with tf.init_scope():
            self._create_slots([v for (_, v) in grads_and_vars])

        accums = []
        variables = []

        for g, v in grads_and_vars:
            accum = self.get_slot(v, 'grad_accum')
            variables.append(v)
            # Scale by 1/num_micro_batches so that, presumably starting from a
            # zero accumulator, the sum over a full cycle is the mean gradient.
            # pytype: disable=attribute-error
            if isinstance(g, tf.IndexedSlices):
                scaled_grad = tf.IndexedSlices(g.values /
                                               self._num_micro_batches,
                                               g.indices,
                                               dense_shape=g.dense_shape)
            else:
                scaled_grad = g / self._num_micro_batches
            accum_tensor = accum.read_value()
            accums.append(accum.assign(accum_tensor + scaled_grad))
            # pytype: enable=attribute-error

        def _ApplyAndReset():
            """Applies the accumulators, zeroes them and bumps global_step."""
            normalized_accums = accums
            if self._apply_crs_to_grad:
                normalized_accums = [
                    tf.tpu.cross_replica_sum(accum.read_value())
                    for accum in accums
                ]
            apply_op = self._opt.apply_gradients(
                list(zip(normalized_accums, variables)))
            with tf.control_dependencies([apply_op]):
                zero_op = [
                    tf.assign(accum, tf.zeros_like(accum)) for accum in accums
                ]
            return tf.group(zero_op, tf.assign_add(global_step, 1))

        def _Accum():
            return tf.no_op()

        accum_step = tf.cond(
            tf.equal(
                tf.math.floormod(self._counter + 1, self._num_micro_batches),
                0),
            _ApplyAndReset,  # Apply the accumulated gradients and reset.
            _Accum)  # Accumulate gradients.

        with tf.control_dependencies([tf.group(accums)]):
            # Increment the counter only after `accum_step` has run, so the
            # cond predicate above always reads the pre-increment value
            # (without this, the read and the assign_add are unordered).
            with tf.control_dependencies([accum_step]):
                counter_update = tf.assign_add(self._counter, 1)
            return tf.group(accum_step, counter_update, name=name)
Esempio n. 10
0
        def _Acc(vg):
            """Adds a gradient into its variable's accumulator.

            Args:
              vg: A (variable, gradient) pair.

            Returns:
              A `py_utils.VarGrad` pairing the variable with the accumulator
              value after `gradient` has been added to it.
            """
            var, grad = vg
            with tf.variable_scope(var.op.name):
                # Zero-initialized, non-trainable accumulator shaped like var.
                accum_params = py_utils.WeightParams(
                    var.get_shape(), py_utils.WeightInit.Constant(0.0),
                    self.params.dtype)
                _, accumulator = py_utils.CreateVariable(
                    'grad_accumulator', accum_params, trainable=False)
                accumulated = tf.assign_add(accumulator, grad)
            return py_utils.VarGrad(var, accumulated)
Esempio n. 11
0
        def _Acc(vg):
            """Adds a gradient into its variable's accumulator.

            Args:
              vg: A (variable, gradient) pair.

            Returns:
              A `py_utils.VarGrad` pairing the variable with the accumulator
              value after `gradient` has been added to it.
            """
            var, grad = vg
            # Scope under the variable's name without a trailing ':0' suffix.
            var_name = var.name
            scope_name = var_name[:-2] if var_name.endswith(':0') else var_name
            with tf.variable_scope(scope_name):
                accumulator = py_utils.CreateVariable(
                    'grad_accumulator',
                    py_utils.WeightParams(var.get_shape(),
                                          py_utils.WeightInit.Constant(0.0),
                                          self.params.dtype),
                    trainable=False)
                accumulated = tf.assign_add(accumulator, grad)
            return py_utils.VarGrad(var, accumulated)
Esempio n. 12
0
    def testDecoderFPropDeterministicAttentionDropout(self):
        """Verify that attention dropout is deterministic given fixed seeds."""
        with self.session(use_gpu=False):
            tf.random.set_seed(8372749040)
            noise = py_utils.VariationalNoiseParams(
                None, True, False, seed=1792)
            p = _DecoderParams(noise)
            p.use_while_loop_based_unrolling = False
            p.attention.atten_dropout_prob = 0.5
            p.attention.atten_dropout_deterministic = True

            loss, per_sequence_loss = self._testDecoderFPropHelper(params=p)
            global_step = py_utils.GetGlobalStep()
            self.evaluate(tf.global_variables_initializer())

            def _EvalAndCheck(want_loss, want_per_seq, want_step):
                # Evaluate loss, per-sequence loss and global step together.
                loss_val, per_seq_val, step_val = self.evaluate(
                    [loss, per_sequence_loss, global_step])
                print('loss = ', loss_val, 'per sequence loss = ',
                      per_seq_val)
                self.assertAllClose(want_loss, loss_val)
                self.assertAllClose(want_per_seq, per_seq_val)
                self.assertEqual(want_step, step_val)

            _EvalAndCheck([3.332992, 15.0],
                          [13.942583, 9.632538, 9.677502, 16.742266], 0)

            # Bump global_step and re-evaluate to check that global_step and
            # time_step are incremented correctly.
            self.evaluate(tf.assign_add(global_step, 1))
            _EvalAndCheck([3.565631, 15.0],
                          [14.560061, 10.566417, 10.554007, 17.803982], 1)
Esempio n. 13
0
    def testAccumulator(self):
        """Checks Accumulator(SGD, accum_steps=2) against explicit averaging.

        The first graph runs two chained SGD updates with each gradient scaled
        by 1/2. The second graph feeds the same two inputs, one per step,
        through an Accumulator-wrapped SGD. Losses, gradients, accumulator
        contents and variable values of the two graphs are then compared.
        """
        # testAccumulator compares
        #   - explicit averaging of independently computed var_grads1 and
        #     var_grads2,
        #   - Accumulator(SGD) optimizer effectively doing this over 2 steps.
        np.random.seed(12345)
        np_input1 = np.random.normal(0.1, 0.5, [2, 4, 3])
        np.random.seed(12346)
        np_input2 = np.random.normal(0.1, 0.5, [2, 4, 3])

        # Reference graph: two placeholders, two half-scaled SGD updates.
        with self.session(use_gpu=True, graph=tf.Graph()) as sess:
            tf.random.set_seed(123456)
            params = layers.ProjectionLayer.Params()
            params.name = 'proj'
            params.dtype = tf.float64
            params.input_dim = 3
            params.output_dim = 2
            params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)

            params.batch_norm = False
            proj_layer = layers.ProjectionLayer(params)
            inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
            in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
            inputs2 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
            in_padding2 = tf.zeros([2, 4, 1], dtype=tf.float64)
            output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
            output2 = proj_layer.FPropDefaultTheta(inputs2, in_padding2)
            loss1 = tf.reduce_sum(output1)
            loss2 = tf.reduce_sum(output2)
            var_grads1 = py_utils.ComputeGradients(loss1, proj_layer.vars)
            var_grads2 = py_utils.ComputeGradients(loss2, proj_layer.vars)
            op = optimizer.SGD.Params()
            opt = op.Instantiate()
            lr = 1e-1
            # Both updates wait for the two losses; the second update is also
            # ordered after the first one.
            with tf.control_dependencies([loss1, loss2]):
                var_update_op1 = opt.Apply(
                    lr, py_utils.ApplyGradMultiplier(var_grads1, 1. / 2.))
                with tf.control_dependencies([var_update_op1]):
                    var_update_op2 = opt.Apply(
                        lr, py_utils.ApplyGradMultiplier(var_grads2, 1. / 2.))

            self.evaluate(tf.global_variables_initializer())
            vars1 = self.evaluate(proj_layer.vars.Flatten())
            loss1_1, grads1_1, loss1_2, grads1_2 = sess.run(
                [
                    loss1,
                    var_grads1.Transform(tuple), loss2,
                    var_grads2.Transform(tuple)
                ],
                feed_dict={
                    inputs1: np_input1,
                    inputs2: np_input2,
                },
            )
            # Running var_update_op2 also runs var_update_op1 through its
            # control dependency.
            sess.run([var_update_op2],
                     feed_dict={
                         inputs1: np_input1,
                         inputs2: np_input2,
                     })
            vars1_1 = self.evaluate(proj_layer.vars.Flatten())

        # Accumulator graph: one placeholder, fed np_input1 then np_input2.
        with self.session(use_gpu=True, graph=tf.Graph()) as sess:
            tf.random.set_seed(123456)
            params = layers.ProjectionLayer.Params()
            params.name = 'proj'
            params.dtype = tf.float64
            params.input_dim = 3
            params.output_dim = 2
            params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)

            params.batch_norm = False
            proj_layer = layers.ProjectionLayer(params)
            in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
            inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
            output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
            loss = tf.reduce_sum(output1)
            var_grads = py_utils.ComputeGradients(loss, proj_layer.vars)
            op = optimizer.Accumulator.Params().Set(
                accum_steps=2,
                dtype=tf.float64,
                optimizer_tpl=optimizer.SGD.Params())
            opt = op.Instantiate()
            lr = 1e-1
            with cluster_factory.ForTestingWorker(add_summary=True):
                var_update_op = opt.Apply(lr, var_grads)
            increment_global_step_op = tf.assign_add(
                py_utils.GetOrCreateGlobalStepVar(), 1)

            self.evaluate(tf.global_variables_initializer())
            vars2 = self.evaluate(proj_layer.vars.Flatten())
            loss2_1, grads2_1 = sess.run(
                [loss, var_grads.Transform(tuple)],
                feed_dict={
                    inputs1: np_input1,
                })
            loss2_2, grads2_2 = sess.run(
                [loss, var_grads.Transform(tuple)],
                feed_dict={
                    inputs1: np_input2,
                })
            # Value of the first 'grad_accumulator' variable before any update.
            acc_0 = self.evaluate([
                v for v in tf.global_variables()
                if 'grad_accumulator' in v.name
            ])[0]
            sess.run([var_update_op], feed_dict={
                inputs1: np_input1,
            })
            # ...after the first (accumulate-only) step.
            acc_1 = self.evaluate([
                v for v in tf.global_variables()
                if 'grad_accumulator' in v.name
            ])[0]
            vars2_intermediate = self.evaluate(proj_layer.vars.Flatten())
            self.evaluate(increment_global_step_op)
            sess.run([var_update_op], feed_dict={
                inputs1: np_input2,
            })
            # ...after the second step, which applies and resets.
            acc_2 = self.evaluate([
                v for v in tf.global_variables()
                if 'grad_accumulator' in v.name
            ])[0]
            vars2_1 = self.evaluate(proj_layer.vars.Flatten())

            summary = tf.Summary.FromString(
                self.evaluate(tf.summary.merge_all()))
            tf.logging.info(f'summary: {summary}')
            self.assertEqual(summary.value[0].tag, 'sgd_lr')

        # Both graphs start from the same initial weights.
        self.assertAllClose(vars1, vars2)

        # Accumulator: zero initially, holds the gradient of 'w' after the
        # first step, and is back to zero after the second step.
        self.assertAllClose(acc_0, np.zeros_like(acc_0))
        self.assertAllClose(acc_1, grads2_1['w'][1])
        self.assertAllClose(acc_2, np.zeros_like(acc_0))

        # Losses and gradients match between the two graphs.
        self.assertAllClose(loss1_1, loss2_1)
        self.assertAllClose(loss1_2, loss2_2)
        self.assertAllClose(grads1_1, grads2_1)
        self.assertAllClose(grads1_2, grads2_2)

        # The first accumulation step leaves the variables unchanged.
        self.assertAllClose(vars1, vars2_intermediate)

        # Element [0] of each (var, grad) tuple is the initial value of 'w'.
        self.assertAllClose(vars2[0], grads2_1['w'][0])
        self.assertAllClose(vars2[0], grads2_2['w'][0])

        # In both graphs the final 'w' is the initial one minus lr times the
        # mean of the two gradients.
        self.assertAllClose(
            vars1[0] - 0.5 * lr * (grads1_1['w'][1] + grads1_2['w'][1]),
            vars1_1[0])

        self.assertAllClose(
            vars2[0] - 0.5 * lr * (grads2_1['w'][1] + grads2_2['w'][1]),
            vars2_1[0])

        self.assertAllClose(vars2, vars2_intermediate)
        self.assertAllClose(vars1_1, vars2_1)
Esempio n. 14
0
  def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
    """Populates the train_ops dictionary in a backwards pass.

    Args:
      vmap: Passed unchanged to each learner's `Apply`.
      metrics: Passed to each learner's `Apply`; defaults to `self._metrics`
        when falsy.
      add_summary: If True, registers each learner's eval metrics (suffixed
        with '/<learner name>') and any TPU embedding summary tensors via
        `AddEvalMetric`.

    Returns:
      A dict of ops with keys 'mask_updates', 'global_step' and, when a
      TPU_EMBEDDING collection exists, 'tpu_embedding'. The per-learner train
      ops and batch-norm updates are NOT in the returned dict; they are only
      reached through the control dependencies built below.
    """
    metrics = metrics or self._metrics

    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      # Contract each per-input mask with the current input-source one-hot.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in self._per_input_gradient_mask.items()
      }
    all_losses = []
    for optimization in self.learners:
      learner_name = optimization.params.name
      # Apply returns (losses, train op, eval metrics) for this learner.
      (losses, train_ops['train/%s' % learner_name],
       eval_metrics) = optimization.Apply(
           metrics,
           vmap,
           gradient_mask=gradient_mask,
           gradient_adjuster=self.AdjustGradients)
      all_losses.extend(losses)
      if add_summary:
        for key, (value, weight) in eval_metrics.items():
          self.AddEvalMetric(key + '/' + learner_name, value, weight)

    # Presumably keeps only the batch-norm updates that feed into the losses.
    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    var_update_ops = [
        tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
    ]
    # Post training step update.
    with tf.control_dependencies(var_update_ops):
      post_step_op = self.PostTrainingStepUpdate()

    # Start a fresh dict: the ops collected above are reached only through the
    # chain var_update_ops -> post_step_op -> mask_update_op -> global step
    # increment (assuming the helpers create new ops inside these contexts).
    train_ops = {}
    with tf.control_dependencies([post_step_op]):
      # Get the op to update the weight masks and thresholds
      mask_update_op = self._GetMaskUpdateOp()
      train_ops['mask_updates'] = mask_update_op
      with tf.control_dependencies([mask_update_op]):
        true_global_step = py_utils.GetOrCreateGlobalStepVar()
        with tf.ops.colocate_with(true_global_step):
          if self.params.defer_global_step_update:
            # NOTE(review): no new op is created here, so the control
            # dependencies above do not attach to increment_global_steps unless
            # the _global_step_var branch below runs; confirm the train ops are
            # still executed in this mode.
            increment_global_steps = true_global_step
          else:
            increment_global_steps = tf.assign_add(true_global_step, 1)
        # Keep a separate per-task global step variable in sync, if any.
        if self._global_step_var != true_global_step:
          with tf.ops.colocate_with(self._global_step_var):
            increment_global_steps = tf.group(
                increment_global_steps, tf.assign_add(self._global_step_var, 1))
        train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    if tf.get_collection(py_utils.TPU_EMBEDDING):
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      sparse_grads = (
          tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
              tpu_embedding))
      tpu_embedding_send_gradient_op = tpu_embedding.generate_send_gradients_op(
          sparse_grads, py_utils.GetGlobalStep())
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

      tpu_embedding_summary_tensors = tf.get_collection(
          py_utils.TPU_EMBEDDING_SUMMARY_TENSORS)
      if add_summary:
        for name, value, weight in tpu_embedding_summary_tensors:
          self.AddEvalMetric(name, value, weight, raise_if_already_added=False)

    # Sanity check: every returned entry must be a real op.
    for op_name, op in train_ops.items():
      assert op is not None, op_name
    return train_ops
Esempio n. 15
0
  def FProp(self, theta, x, paddings=None, update=False):
    """Computes distances of the given input 'x' to all centroids.

    When `p.apply_layer_norm` is set, layer normalization is applied to 'x'
    first, and the returned 'dists' is computed using the normalized 'x'.
    Distances are approximated by the negative dot product with the centroids.

    Args:
      theta: A `.NestedMap` of weights' values of this layer.
      x: A tensor of shape [B, L, N, H].
      paddings: If not None, a tensor of shape [B, L].
      update: bool, whether to update centroids using x. When True, the
        returned 'dists' carries a control dependency on an exponential
        moving-average update of `self.vars.means`.

    Returns:
      dists: "distances" of the given input 'x' to all centroids.
             Shape [B, L, N, K].
      k_means_loss: the average squared Euclidean distances to the closest
                    centroid, a scalar (0 when every position is padded).
    """
    p = self.params
    if paddings is None:
      paddings = tf.zeros_like(x[:, :, 0, 0])
    # Shape [B, L, 1, 1]
    paddings_4d = paddings[:, :, None, None]

    if p.apply_layer_norm:
      x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

    # 'x' is normalized (but theta.means is not), we use negative dot product to
    # approximate the Euclidean distance here.
    dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

    # For padded positions we update the distances to very large numbers.
    very_large_dists = tf.ones_like(dists) * tf.constant(
        0.1, dtype=dists.dtype) * dists.dtype.max
    paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
    dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

    # Shape [B, L, N, K], the same as 'dists' above.
    nearest_one_hot = tf.one_hot(
        tf.math.argmin(dists, axis=-1),
        p.num_clusters,
        dtype=py_utils.FPropDtype(p))
    # Same shape as the input 'x'.
    nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                 theta.means)
    diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
    diff = py_utils.ApplyPadding(paddings_4d, diff)
    diff = tf.math.reduce_mean(diff, axis=2)

    # The commitment loss which when back proped against encourages the 'x'
    # values to commit to their chosen centroids. divide_no_nan returns 0
    # instead of NaN when every position is padded.
    k_means_loss = tf.math.divide_no_nan(
        tf.math.reduce_sum(diff), tf.math.reduce_sum(1.0 - paddings))
    summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

    # TODO(zhouwk): investigate normalizing theta.means after each update.
    # Per-centroid L2 norms, shape [N, K]. Without `axis`, tf.norm returns a
    # single norm of the whole tensor, which made the min and mean summaries
    # identical.
    means_norm = tf.norm(theta.means, axis=-1)
    summary_utils.scalar('k_means/centroid_l2_norm/min',
                         tf.math.reduce_min(means_norm))
    summary_utils.scalar('k_means/centroid_l2_norm/mean',
                         tf.math.reduce_mean(means_norm))

    if not update:
      return dists, k_means_loss

    # To update the centroids (self.vars.means), we apply gradient descent on
    # the mini-batch of input 'x', which yields the following:
    #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
    # where x_mean is the average over all the input vectors closest to this
    # centroid.
    #
    # Note that this approach is equivalent with backprop via
    #    loss = tf.math.reduce_mean(
    #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
    # , except that here the learning rate is independently set via 'decay'.

    # Ensure that the padded positions are not used to update the centroids.
    nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

    # Sum away batch and sequence length dimensions to get per cluster count.
    # Shape: [N, K]
    per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
    summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

    # Sum of the input 'x' per each closest centroid.
    sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

    if py_utils.use_tpu():
      per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
      sum_x = tf.tpu.cross_replica_sum(sum_x)

    # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
    # cluster's position will always be 0, hence 'sum_x' in that dimension will
    # be 0.
    new_means = sum_x / tf.maximum(
        tf.constant(1.0, dtype=per_cluster_count.dtype),
        tf.expand_dims(per_cluster_count, axis=-1))

    # We use exponential moving average. TODO(zhouwk): investigate smooth this
    # over an exponentially moving averaged per cluster count.
    #
    # Note that we intentionally do not normalize the means after this update
    # as empirically this works better.
    update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                                self.vars.means.dtype)
    return py_utils.with_dependencies(
        [tf.assign_add(self.vars.means, update_means_diff)],
        dists), k_means_loss
Esempio n. 16
0
  def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
    """Populates the train_ops dictionary in a backwards pass.

    Args:
      vmap: Passed unchanged to each learner's `Apply`.
      metrics: Dict keyed by loss name; element [0] of the selected entry is
        used as the learner's loss. Defaults to `self._metrics` when falsy.
      add_summary: If True, registers each learner's eval metrics via
        `AddEvalMetric`, suffixed with '/<learner name>'.

    Returns:
      A dict of ops with keys 'mask_updates', 'global_step' and, when the
      TPU_EMBEDDING_ACTIVATIONS collection is non-empty, 'tpu_embedding'. The
      per-learner train ops and batch-norm updates are NOT in the returned
      dict; they are only reached through the control dependencies below.

    Raises:
      ValueError: if a learner's loss name is not a key of `metrics`.
    """
    metrics = metrics or self._metrics

    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      # Contract each per-input mask with the current input-source one-hot.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in self._per_input_gradient_mask.items()
      }
    all_losses = []
    for optimization in self.learners:
      learner_name = optimization.params.name
      # A learner may optimize a differently named loss than its own name.
      loss_name = optimization.params.loss_name or learner_name
      metric = metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(metrics.keys())))
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % learner_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      if add_summary:
        for key, (value, weight) in eval_metrics.items():
          self.AddEvalMetric(key + '/' + learner_name, value, weight)

    # Presumably keeps only the batch-norm updates that feed into the losses.
    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    var_update_ops = [
        tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
    ]
    # Post training step update.
    with tf.control_dependencies(var_update_ops):
      post_step_op = self.PostTrainingStepUpdate(self.global_step)

    # Start a fresh dict: the ops collected above are reached only through the
    # chain var_update_ops -> post_step_op -> mask_update_op -> global step
    # increment (assuming the helpers create new ops inside these contexts).
    train_ops = {}
    with tf.control_dependencies([post_step_op]):
      # Get the op to update the weight masks and thresholds
      mask_update_op = self._GetMaskUpdateOp()
      train_ops['mask_updates'] = mask_update_op
      with tf.control_dependencies([mask_update_op]):
        true_global_step = py_utils.GetOrCreateGlobalStepVar()
        with tf.ops.colocate_with(true_global_step):
          increment_global_steps = tf.assign_add(true_global_step, 1)
        # Keep a separate per-task global step variable in sync, if any.
        if self._global_step_var != true_global_step:
          with tf.ops.colocate_with(self._global_step_var):
            increment_global_steps = tf.group(
                increment_global_steps, tf.assign_add(self._global_step_var, 1))
        train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    # Sanity check: every returned entry must be a real op.
    for op_name, op in train_ops.items():
      assert op is not None, op_name
    return train_ops
Esempio n. 17
0
  def _BPropForVariables(self, vmap):
    """Constructs the backward graph.

    Builds one train op per learner plus batch-norm, mask and post-step
    updates, then a global-step increment that depends on all of those ops,
    and stores the grouped result (named 'bprop') in `self._train_op`.

    Args:
      vmap: Passed unchanged to each learner's `Apply`.

    Raises:
      ValueError: if a learner's name is not a key of `self._metrics`.
    """
    bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
    # Only compute the mask if the variable filters are not empty.
    if bprop_variable_filters != [''] * len(bprop_variable_filters):
      self._ComputeGradientMask(bprop_variable_filters)
    train_ops = {}  # mapping from op name to op.
    gradient_mask = None
    if self._per_input_gradient_mask:
      # TODO(neerajgaur): Change this to use source_selected from input_batch.
      # Contract each per-input mask with the current input-source one-hot.
      onehot = self.input_generator.GetInputSourceOneHot()
      gradient_mask = {
          k: tf.tensordot(v, onehot, 1)
          for k, v in six.iteritems(self._per_input_gradient_mask)
      }
    all_losses = []
    for optimization in self.learners:
      # Each learner optimizes the metric that shares its name; element [0] of
      # that metric is used as the loss.
      loss_name = optimization.params.name
      metric = self._metrics.get(loss_name, None)
      if metric is None:
        raise ValueError('Loss %s not found in metrics %s' %
                         (loss_name, list(self._metrics.keys())))
      loss = metric[0]
      all_losses.append(loss)
      train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
          loss,
          vmap,
          gradient_mask=gradient_mask,
          gradient_adjuster=self.AdjustGradients)
      for key, (value, weight) in six.iteritems(eval_metrics):
        self.AddEvalMetric(key + '/' + loss_name, value, weight)

    # Presumably keeps only the batch-norm updates that feed into the losses.
    relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
        all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
    train_ops['bn_updates'] = relevant_bn_updates

    # Get the op to update the weight masks and thresholds
    train_ops['mask_updates'] = self._GetMaskUpdateOp()

    # Post training step update.
    train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

    # The global step increment runs only after every op collected so far.
    with tf.control_dependencies(tf.nest.flatten(train_ops)):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      # Keep a separate per-task global step variable in sync, if any.
      if self._global_step_var != true_global_step:
        with tf.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    # Sanity check: every collected entry must be a real op.
    for op_name, op in six.iteritems(train_ops):
      assert op is not None, op_name

    # TODO(rpang): try to structure _train_op as:
    #   tf.cond(skip_step, <only update skip stats>, <all updates>)
    # so that we skip all other updates when a step is skipped.
    self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')