Exemplo n.º 1
0
    def ComputeLoss(self, theta, predictions, input_batch):
        """Blends the ground-truth loss with the distillation loss.

        Args:
          theta: Provides `.student`, `.teacher` and
            `.distillation_loss_weight` sub-entries, each forwarded to the
            matching sub-component.
          predictions: Provides `.student` and `.teacher` predictions; the
            whole object is also forwarded to `ComputeDistillationLoss`.
          input_batch: Forwarded unchanged to every loss computation.

        Returns:
          A `(metrics, per_example)` pair. `metrics` is the result of
          `py_utils.CombineMetrics` over the ground-truth metrics (weighted by
          `1 - w`) and the distillation metrics (weighted by `w`), where `w`
          is produced by `self.distillation_loss_weight.FProp`. `per_example`
          is a plain dict holding the student's and the distillation
          per-example outputs (distillation entries win on key collisions).
        """
        params = self.params

        with tf.name_scope('groundtruth_loss'):
            student_loss, student_per_example = self.student.ComputeLoss(
                theta.student, predictions.student, input_batch)
            # Not a copy: the ground-truth metrics ARE the student's metrics
            # map, so every key written below lands on that same object.
            groundtruth_loss = student_loss
            groundtruth_loss['student_groundtruth_loss'] = student_loss['loss']

            if params.train_teacher:
                # The teacher's per-example outputs are discarded.
                teacher_loss, _ = self.teacher.ComputeLoss(
                    theta.teacher, predictions.teacher, input_batch)
                groundtruth_loss['teacher_groundtruth_loss'] = (
                    teacher_loss['loss'])
                # Replace 'loss' with the weighted average of the teacher and
                # student losses. zip() regroups the two entries element-wise
                # before they are handed to WeightedAvg -- presumably each
                # entry is a (value, weight) pair; confirm against py_utils.
                regrouped = zip(teacher_loss['loss'], student_loss['loss'])
                groundtruth_loss['loss'] = py_utils.WeightedAvg(*regrouped)

        with tf.name_scope('distillation_loss'):
            distill_loss, distill_per_example = self.ComputeDistillationLoss(
                theta, predictions, input_batch)
            distill_loss['distillation_loss'] = distill_loss['loss']

        # Merge per-example outputs in the order they were produced.
        per_example = {}
        for outputs in (student_per_example, distill_per_example):
            per_example.update(outputs)

        distill_weight = self.distillation_loss_weight.FProp(
            theta.distillation_loss_weight, self.global_step)
        weighted_metrics = [
            (groundtruth_loss, 1 - distill_weight),
            (distill_loss, distill_weight),
        ]
        return py_utils.CombineMetrics(weighted_metrics), per_example
Exemplo n.º 2
0
 def testCombineMetricsKeyNotInAllMetrics(self):
   """Expects CombineMetrics to raise ValueError for these inputs.

   The first map carries no 'loss' entry while the other two do.
   """
   without_loss = py_utils.NestedMap()
   without_loss['a'] = (1, 1)

   b_and_loss = py_utils.NestedMap()
   b_and_loss['b'] = (2, 2)
   b_and_loss['loss'] = (50, 20)

   loss_only = py_utils.NestedMap()
   loss_only['loss'] = (60, 15)

   weighted_inputs = [(without_loss, 0.7), (b_and_loss, 0.3), (loss_only, 1.5)]
   self.assertRaises(ValueError, py_utils.CombineMetrics, weighted_inputs)
Exemplo n.º 3
0
 def testCombineMetrics(self):
   """Checks CombineMetrics on three maps that all carry a 'loss' entry.

   'a' and 'b' each appear in a single input and are expected back
   unchanged; the combined 'loss' entry, multiplied out, is expected to equal
   the sum of each input's value * weight * scale.
   """
   first = py_utils.NestedMap()
   first['a'] = (1, 1)
   first['loss'] = (100, 10)

   second = py_utils.NestedMap()
   second['b'] = (2, 2)
   second['loss'] = (50, 20)

   third = py_utils.NestedMap()
   third['loss'] = (60, 15)

   weighted_inputs = [(first, 0.7), (second, 0.3), (third, 1.5)]
   combined = py_utils.CombineMetrics(weighted_inputs)

   self.assertEqual(combined['a'], (1, 1))
   self.assertEqual(combined['b'], (2, 2))

   loss_entry = combined['loss']
   total_loss = loss_entry[0] * loss_entry[1]
   # Term order is kept as-is: the comparison below is exact, so the float
   # arithmetic must be evaluated in the same order.
   expected_total = 100 * 10 * 0.7 + 50 * 20 * 0.3 + 60 * 15 * 1.5
   self.assertEqual(total_loss, expected_total)
Exemplo n.º 4
0
  def ComputeLoss(self, theta, input_batch, predictions):
    """Blends the student's ground-truth loss with the distillation loss.

    Args:
      theta: Provides `.student` and `.distillation_loss_weight` sub-entries;
        the whole object is also forwarded to `ComputeDistillationLoss`.
      input_batch: Forwarded unchanged to both loss computations.
      predictions: Provides `.student` predictions; the whole object is also
        forwarded to `ComputeDistillationLoss`.

    Returns:
      The result of `py_utils.CombineMetrics` over the student's metrics
      (weighted by `1 - w`) and the distillation metrics (weighted by `w`),
      where `w` comes from `self.distillation_loss_weight.FProp`.
    """
    with tf.name_scope('groundtruth_loss'):
      student_metrics = self.student.ComputeLoss(
          theta.student, input_batch, predictions.student)
      # Keep the pre-combination value available under its own key.
      student_metrics['groundtruth_loss'] = student_metrics['loss']

    with tf.name_scope('distillation_loss'):
      distill_metrics = self.ComputeDistillationLoss(
          theta, input_batch, predictions)
      distill_metrics['distillation_loss'] = distill_metrics['loss']

    distill_weight = self.distillation_loss_weight.FProp(
        theta.distillation_loss_weight, self._global_step)
    weighted_metrics = [
        (student_metrics, 1 - distill_weight),
        (distill_metrics, distill_weight),
    ]
    return py_utils.CombineMetrics(weighted_metrics)
Exemplo n.º 5
0
  def ComputeLoss(self, theta, predictions, input_batch):
    """Blends the student's ground-truth loss with the distillation loss.

    Args:
      theta: Provides `.student` and `.distillation_loss_weight` sub-entries;
        the whole object is also forwarded to `ComputeDistillationLoss`.
      predictions: Provides `.student` predictions; the whole object is also
        forwarded to `ComputeDistillationLoss`.
      input_batch: Forwarded unchanged to both loss computations.

    Returns:
      A `(metrics, per_example)` pair. `metrics` is the result of
      `py_utils.CombineMetrics` over the student's metrics (weighted by
      `1 - w`) and the distillation metrics (weighted by `w`), where `w`
      comes from `self.distillation_loss_weight.FProp`. `per_example` is a
      plain dict holding the student's and the distillation per-example
      outputs (distillation entries win on key collisions).
    """
    with tf.name_scope('groundtruth_loss'):
      student_metrics, student_per_example = self.student.ComputeLoss(
          theta.student, predictions.student, input_batch)
      # Keep the pre-combination value available under its own key.
      student_metrics['groundtruth_loss'] = student_metrics['loss']

    with tf.name_scope('distillation_loss'):
      distill_metrics, distill_per_example = self.ComputeDistillationLoss(
          theta, predictions, input_batch)
      distill_metrics['distillation_loss'] = distill_metrics['loss']

    # Merge per-example outputs in the order they were produced.
    per_example = {}
    for outputs in (student_per_example, distill_per_example):
      per_example.update(outputs)

    distill_weight = self.distillation_loss_weight.FProp(
        theta.distillation_loss_weight, self.global_step)
    weighted_metrics = [
        (student_metrics, 1 - distill_weight),
        (distill_metrics, distill_weight),
    ]
    return py_utils.CombineMetrics(weighted_metrics), per_example