def ComputeLoss(self, theta, predictions, input_batch):
    """Blend the ground-truth loss with the distillation loss.

    The student's ground-truth metrics are always computed. When
    `params.train_teacher` is set, the teacher's ground-truth loss is
    computed as well and folded into the ground-truth 'loss' entry. The
    ground-truth metrics and the distillation metrics are then combined,
    weighted by `1 - w` and `w` respectively, where `w` is the current value
    of `self.distillation_loss_weight`.

    Args:
      theta: Variable values; `theta.student` and `theta.teacher` are handed
        to the corresponding sub-models.
      predictions: Model predictions; `predictions.student` and
        `predictions.teacher` are handed to the corresponding sub-models.
      input_batch: The input batch, forwarded unchanged to every loss
        computation.

    Returns:
      A `(metrics, per_example)` tuple: the result of
      `py_utils.CombineMetrics` and a dict merging the student's and the
      distillation per-example outputs.
    """
    params = self.params
    example_outputs = {}

    with tf.name_scope('groundtruth_loss'):
        student_metrics, student_examples = self.student.ComputeLoss(
            theta.student, predictions.student, input_batch)
        student_loss = student_metrics['loss']

        # NOTE: this is the very same dict object the student returned, not a
        # copy, so the extra keys written below also land in that dict.
        groundtruth_loss = student_metrics
        groundtruth_loss['student_groundtruth_loss'] = student_loss
        example_outputs.update(student_examples)

        if params.train_teacher:
            teacher_metrics, _ = self.teacher.ComputeLoss(
                theta.teacher, predictions.teacher, input_batch)
            teacher_loss = teacher_metrics['loss']
            groundtruth_loss['teacher_groundtruth_loss'] = teacher_loss

            # The new loss is the weighted average of the teacher and student
            # losses. zip() regroups the two 'loss' entries element-wise
            # (presumably (value, weight) pairs -- confirm against the
            # sub-models) into the positional arguments of WeightedAvg.
            teacher_and_student = zip(teacher_loss, student_loss)
            groundtruth_loss['loss'] = py_utils.WeightedAvg(
                *teacher_and_student)

    with tf.name_scope('distillation_loss'):
        distillation_loss, distillation_examples = (
            self.ComputeDistillationLoss(theta, predictions, input_batch))
        distillation_loss['distillation_loss'] = distillation_loss['loss']
        example_outputs.update(distillation_examples)

    distill_weight = self.distillation_loss_weight.Value()
    weighted_metrics = [
        (groundtruth_loss, 1 - distill_weight),
        (distillation_loss, distill_weight),
    ]
    return py_utils.CombineMetrics(weighted_metrics), example_outputs
def testWeightedAvg(self):
    """Checks the average and total weight returned by py_utils.WeightedAvg.

    Worked by hand: (5.6*10 + 4.6*9 + 1.5*2 + 3.4*8) / (10 + 9 + 2 + 8)
    = 127.6 / 29 = 4.4, and the weights sum to 29.
    """
    with self.session(use_gpu=False) as sess:
        per_item_loss = tf.constant([5.6, 4.6, 1.5, 3.4])
        per_item_weight = tf.constant([10, 9, 2, 8])
        avg_loss, total_weight = py_utils.WeightedAvg(
            per_item_loss, per_item_weight)
        fetched = sess.run([avg_loss, total_weight])
        self.assertAllClose(fetched, [4.4, 29])
def FinalizeMetrics(self, loop_result):
    """Compute the final average of every metric from the loop result.

    To be called outside the training loop body, but still in the scope of
    tpu.batch_parallel.

    Args:
      loop_result: Result of the training loop. Its leading
        `2 * len(self._metrics)` tensors are the ones consumed here.

    Returns:
      A flat list `[value_0, weight_0, value_1, weight_1, ...]` holding the
      final average value and total weight of each metric.
    """
    # Every metric occupies two tensors in the loop-carried result.
    num_metric_tensors = 2 * len(self._metrics.items())

    # Sum each of those tensors across TPU replicas.
    replica_totals = [
        tf.tpu.cross_replica_sum(tensor)
        for tensor in loop_result[:num_metric_tensors]
    ]

    finalized = []
    for summed_value, summed_weight in self._Zip(replica_totals):
        avg_value, total_weight = py_utils.WeightedAvg(
            summed_value / summed_weight, summed_weight)
        finalized.extend([avg_value, total_weight])
    return finalized