def setUp(self):
  """Create controller parameters with is_train=True and derive loss dicts.

  The two dictionaries returned by `controller.setup_losses` are stored on
  `self.input_loss_lts` and `self.target_loss_lts` for the tests to inspect.
  """
  super(SetupLossesTest, self).setUp()
  # NOTE(review): the literal 110 is passed positionally as the third
  # argument; its meaning is defined by GetInputTargetAndPredictedParameters
  # and is not visible here -- confirm before changing it.
  params = controller.GetInputTargetAndPredictedParameters(
      self.dp,
      self.ap,
      110,
      self.stride,
      self.stitch_patch_size,
      self.bp,
      self.core_model,
      self.add_head,
      self.shuffle,
      self.num_classes,
      util.softmax_cross_entropy,
      True,  # is_train
  )
  losses = controller.setup_losses(params)
  self.input_loss_lts, self.target_loss_lts = losses
def total_loss(
    gitapp: controller.GetInputTargetAndPredictedParameters,
    input_weight: float = 0.5,
) -> Tuple[tf.Tensor, Dict[str, lt.LabeledTensor], Dict[str, lt.LabeledTensor]]:
  """Get the total weighted training loss.

  The input and target loss dictionaries produced by
  `controller.setup_losses(gitapp)` are each averaged over their entries
  (element-wise sum via `tf.add_n`, divided by the entry count), and the two
  means are combined as

    input_weight * mean(input losses) + (1 - input_weight) * mean(target losses)

  The result is also recorded with `tf.summary.scalar` under 'total_loss'.

  Args:
    gitapp: Parameters forwarded unchanged to `controller.setup_losses`.
    input_weight: Weight in [0, 1] applied to the mean input loss; the mean
      target loss receives `1 - input_weight`. The default of 0.5 gives both
      the same weight, which matches the previously hard-coded behavior.

  Returns:
    A tuple `(total_loss_op, input_loss_lts, target_loss_lts)` where the two
    dictionaries are exactly those returned by `controller.setup_losses`.

  Raises:
    ValueError: If `input_weight` is outside [0, 1], or if either loss
      dictionary is empty (there is nothing to average).
  """
  if not 0.0 <= input_weight <= 1.0:
    raise ValueError(
        'input_weight must be in [0, 1], got %r' % (input_weight,))

  input_loss_lts, target_loss_lts = controller.setup_losses(gitapp)

  def mean(lts: Dict[str, lt.LabeledTensor], which: str) -> tf.Tensor:
    """Average the tensors of `lts`; raise a clear error when it is empty."""
    if not lts:
      # tf.add_n([]) would also raise ValueError, but with a message that
      # does not say which loss dictionary was empty.
      raise ValueError('No %s losses to average.' % which)
    sum_op = tf.add_n([t.tensor for t in lts.values()])
    return sum_op / float(len(lts))

  total_loss_op = input_weight * mean(input_loss_lts, 'input') + (
      1 - input_weight) * mean(target_loss_lts, 'target')
  tf.summary.scalar('total_loss', total_loss_op)
  return total_loss_op, input_loss_lts, target_loss_lts