示例#1
0
    def ComputeLoss(self, theta, predictions, input_batch):
        """Computes ground-truth and distillation losses and combines them.

        The ground-truth metrics come from the student's ComputeLoss and, when
        `p.train_teacher` is set, from the teacher's ComputeLoss as well. The
        distillation metrics come from ComputeDistillationLoss. Both metric
        dicts are merged with py_utils.CombineMetrics, weighting the
        ground-truth metrics by `1 - w` and the distillation metrics by `w`,
        where `w` is `self.distillation_loss_weight.Value()`.

        Args:
          theta: Variables for this task. `theta.student` / `theta.teacher` are
            passed to the sub-models' ComputeLoss; the whole `theta` is passed
            to ComputeDistillationLoss.
          predictions: Forward-pass outputs. `predictions.student` /
            `predictions.teacher` are passed to the sub-models' ComputeLoss;
            the whole `predictions` is passed to ComputeDistillationLoss.
          input_batch: Input batch, passed unchanged to every loss computation.

        Returns:
          A tuple (metrics, per_example):
            metrics: the result of py_utils.CombineMetrics over the
              ground-truth and distillation metric dicts.
            per_example: a dict merging the student's ground-truth per-example
              tensors with the distillation per-example tensors (the teacher's
              per-example tensors are discarded).
        """
        p = self.params
        per_example = {}
        with tf.name_scope('groundtruth_loss'):
            student_groundtruth_loss, student_groundtruth_per_example = (
                self.student.ComputeLoss(theta.student, predictions.student,
                                         input_batch))
            # NOTE: this aliases the student's metrics dict, so the keys added
            # below are also visible through student_groundtruth_loss.
            groundtruth_loss = student_groundtruth_loss
            groundtruth_loss['student_groundtruth_loss'] = (
                student_groundtruth_loss['loss'])
            per_example.update(student_groundtruth_per_example)

            if p.train_teacher:
                teacher_groundtruth_loss, _ = self.teacher.ComputeLoss(
                    theta.teacher, predictions.teacher, input_batch)
                groundtruth_loss['teacher_groundtruth_loss'] = (
                    teacher_groundtruth_loss['loss'])
                # The new loss is the weighted average (not sum) of the teacher
                # and student losses. Each 'loss' entry is unpacked as a
                # (value, weight) pair; values and weights are then passed to
                # WeightedAvg as parallel tuples.
                teacher_value, teacher_weight = teacher_groundtruth_loss['loss']
                student_value, student_weight = student_groundtruth_loss['loss']
                groundtruth_loss['loss'] = py_utils.WeightedAvg(
                    (teacher_value, student_value),
                    (teacher_weight, student_weight))

        with tf.name_scope('distillation_loss'):
            distillation_loss, distill_per_example = self.ComputeDistillationLoss(
                theta, predictions, input_batch)
            distillation_loss['distillation_loss'] = distillation_loss['loss']
            per_example.update(distill_per_example)

        # Blend the two metric groups: (1 - w) * ground truth + w * distillation.
        distillation_loss_weight = self.distillation_loss_weight.Value()
        metrics = py_utils.CombineMetrics([
            (groundtruth_loss, 1 - distillation_loss_weight),
            (distillation_loss, distillation_loss_weight),
        ])
        return metrics, per_example
示例#2
0
 def testWeightedAvg(self):
   """WeightedAvg yields the weighted mean and the total weight."""
   # sum(v * w) = 56 + 41.4 + 3 + 27.2 = 127.6 and sum(w) = 29,
   # so the weighted mean is 127.6 / 29 = 4.4.
   with self.session(use_gpu=False) as sess:
     avg_t, total_weight_t = py_utils.WeightedAvg(
         tf.constant([5.6, 4.6, 1.5, 3.4]), tf.constant([10, 9, 2, 8]))
     self.assertAllClose(sess.run([avg_t, total_weight_t]), [4.4, 29])
示例#3
0
    def FinalizeMetrics(self, loop_result):
        """Computes final metric averages from the loop's carried tensors.

        Call this outside the training loop body, but still within the scope
        of tpu.batch_parallel.

        Args:
          loop_result: Result of the training loop. Only its first
            2 * len(self._metrics.items()) entries (two tensors per metric)
            are read here.

        Returns:
          A flat list [avg_0, weight_0, avg_1, weight_1, ...] holding, for each
          (value, weight) pair yielded by self._Zip, the averaged value and the
          total weight returned by py_utils.WeightedAvg.
        """
        # Two tensors per metric are carried through the loop.
        num_metric_tensors = 2 * len(self._metrics.items())
        # Sum every carried metric tensor across the TPU replicas.
        replica_sums = [
            tf.tpu.cross_replica_sum(t)
            for t in loop_result[:num_metric_tensors]
        ]
        finalized = []
        for summed_value, summed_weight in self._Zip(replica_sums):
            # NOTE(review): dividing by the summed weight presumes the loop
            # accumulated value * weight sums; confirm against the code that
            # fills the loop result.
            avg, total_weight = py_utils.WeightedAvg(
                summed_value / summed_weight, summed_weight)
            finalized.extend([avg, total_weight])
        return finalized