def _build_loss(self):
    """Build, register and return the (optionally positive-normalized) focal loss.

    Logits, labels and the class count are read from the parent trainer.
    If ``self._normalize_by_positives`` is truthy, the number of non-zero
    labels of each example is counted and handed to ``Loss.focal_loss`` as
    ``num_positives``. Otherwise ``None`` is passed and the resulting loss
    is divided by the parent's batch size instead.

    NOTE(review): how ``Loss.focal_loss`` uses ``num_positives`` (presumably
    as a per-example normalizer) is not visible here -- confirm there.

    Returns:
        The focal-loss tensor, after it has been registered via
        ``track_loss`` under ``FocalTrainer.FOCAL_LOSS``.
    """
    logits = super().get_logits()
    labels = super().get_labels()
    num_classes = super().get_num_classes()

    per_example_positives = None
    if self._normalize_by_positives:
        # 1.0 wherever the label is non-zero, 0.0 elsewhere; same shape as labels.
        is_positive = tf.cast(tf.not_equal(labels, 0), tf.float32)
        # Sum over every axis except the leading (batch) one -> one count per example.
        non_batch_axes = list(range(1, len(is_positive.get_shape())))
        per_example_positives = tf.reduce_sum(is_positive, axis=non_batch_axes)

    loss = Loss.focal_loss(
        logits=logits,
        labels=labels,
        num_classes=num_classes,
        num_positives=per_example_positives,
        focal_gamma=self._focal_gamma,
    )
    if not self._normalize_by_positives:
        # No positive count was supplied, so scale by the batch size instead.
        loss = loss / float(super().get_batch_size())

    super().track_loss(loss, FocalTrainer.FOCAL_LOSS)
    return loss
def _build_loss(self):
    """Build, register and return the per-pixel-weighted focal loss.

    Logits, labels, the class count and a weight map are read from the
    parent trainer. The number of non-zero labels of each example is
    counted and passed to ``Loss.focal_loss`` as ``num_positives``, which
    is called with ``raw_tensor=True``. That result is multiplied
    element-wise by the weight map and summed to a single value.

    NOTE(review): presumably ``raw_tensor=True`` makes ``Loss.focal_loss``
    return the unreduced loss so it can be broadcast against the weight map
    -- confirm against ``Loss.focal_loss``.

    Returns:
        The summed, weighted focal-loss tensor, after it has been registered
        via ``track_loss`` under ``WeightedFocalTrainer.FOCAL_LOSS``.
    """
    logits = super().get_logits()
    labels = super().get_labels()
    num_classes = super().get_num_classes()

    # tf.reduce_sum does not accept bool tensors, so the mask must be cast
    # before counting. float32 matches the unweighted focal trainer, so
    # Loss.focal_loss receives num_positives with the same dtype in both.
    positives = tf.cast(tf.not_equal(labels, 0), tf.float32)  # same shape as labels
    axis = list(range(1, len(positives.get_shape())))
    num_positives = tf.reduce_sum(positives, axis=axis)  # one count per example

    focal_loss = Loss.focal_loss(
        logits=logits,
        labels=labels,
        num_classes=num_classes,
        num_positives=num_positives,
        focal_gamma=self._focal_gamma,
        raw_tensor=True,
    )
    weights = super().get_weight_map()
    # Weight each loss element, then reduce over everything to a scalar.
    focal_loss = tf.reduce_sum(focal_loss * weights)

    super().track_loss(focal_loss, WeightedFocalTrainer.FOCAL_LOSS)
    return focal_loss