Example #1
0
 def test_should_return_higher_value_for_less_frequent_occuring_class(self):
     """The rarer label must receive the larger weight.

     With frequencies [2, 1], label 1 occurs less often than label 0, so its
     computed weight is expected to be strictly greater.
     """
     # A fresh graph plus a default session so that .eval() can run.
     with tf.Graph().as_default(), tf.Session():
         label_counts = [2, 1]
         weights_tensor = tf_calculate_efnet_weights_for_frequency_by_label(
             label_counts)
         weights = weights_tensor.eval()
         assert weights[1] > weights[0]
Example #2
0
 def test_should_return_zero_value_for_not_occuring_class(self):
     """A label that never occurs must get a weight of exactly zero.

     With frequencies [1, 0], the last label has zero occurrences.
     """
     # A fresh graph plus a default session so that .eval() can run.
     with tf.Graph().as_default(), tf.Session():
         label_counts = [1, 0]
         weights_tensor = tf_calculate_efnet_weights_for_frequency_by_label(
             label_counts)
         weights = weights_tensor.eval()
         last_weight = weights[-1]
         assert last_weight == 0.0
Example #3
0
 def test_should_return_same_value_for_classes_with_same_frequencies(self):
     """Labels with identical frequencies must receive identical weights."""
     # A fresh graph plus a default session so that .eval() can run.
     with tf.Graph().as_default(), tf.Session():
         label_counts = [1, 1]
         weights_tensor = tf_calculate_efnet_weights_for_frequency_by_label(
             label_counts)
         weights = weights_tensor.eval()
         first_weight, second_weight = weights[0], weights[1]
         assert first_weight == second_weight
Example #4
0
def _create_pos_weights_tensor(
        base_loss,
        separate_channel_annotation_tensor,
        pos_weight_values,
        input_uri,
        debug):
    """Build the positive-weight tensor from per-channel label frequencies.

    The annotation tensor is summed over its first two axes to get a
    frequency per channel. Those frequencies are turned into weights by
    ``tf_calculate_efnet_weights_for_frequency_by_label``. In the
    weighted-sample-weighted cross-entropy mode, the result is additionally
    multiplied by ``pos_weight_values``.

    Args:
        base_loss: Compared against
            ``BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY``. Only in that
            mode is ``pos_weight_values`` applied.
        separate_channel_annotation_tensor: Tensor summed over axes 0 and 1
            (so rank >= 2). Presumably holds one channel per label in the
            trailing axis -- TODO confirm against the caller.
        pos_weight_values: Multiplied element-wise into the frequency-based
            weights in the weighted-sample mode; ignored otherwise.
        input_uri: Used only as an extra value in the debug print.
        debug: When truthy, the returned tensor is wrapped in ``tf.Print``.
            When evaluated, that wrapper prints the weights, the
            frequency-based weights, the frequencies and ``input_uri``.

    Returns:
        The positive-weight tensor. When ``debug`` is set, it is the
        ``tf.Print`` wrapper around that tensor.
    """
    # Sum over the first two axes. keep_dims=True retains them as size-1
    # axes (presumably so the result broadcasts against per-position
    # tensors later).
    label_frequencies = tf.reduce_sum(
        separate_channel_annotation_tensor,
        axis=[0, 1],
        keep_dims=True,
        name='frequency_by_channel'
    )
    sample_weights = tf_calculate_efnet_weights_for_frequency_by_label(
        label_frequencies
    )

    if base_loss == BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY:
        # Combine the frequency-derived weights with the configured ones.
        weights = sample_weights * pos_weight_values
    else:
        weights = sample_weights

    if debug:
        # The list is captured before rebinding, so it holds the un-wrapped
        # weights tensor.
        debug_values = [weights, sample_weights, label_frequencies, input_uri]
        weights = tf.Print(
            weights,
            debug_values,
            message='pos weights, sample, frequency, uri: ',
            summarize=1000
        )

    # Logs the symbolic tensor objects, not their evaluated values.
    get_logger().debug(
        'pos_weight before batch: %s (frequency_by_label: %s)',
        weights, label_frequencies
    )
    return weights