def test_should_return_higher_value_for_less_frequent_occuring_class(self):
    """A label seen less often must receive a strictly larger weight."""
    graph = tf.Graph()
    # The session is created while `graph` is the default graph, so it runs
    # the op built below; session.run is equivalent to Tensor.eval() here.
    with graph.as_default(), tf.Session() as session:
        weights = session.run(
            tf_calculate_efnet_weights_for_frequency_by_label([2, 1])
        )
    # Label 1 (frequency 1) is rarer than label 0 (frequency 2).
    assert weights[1] > weights[0]
def test_should_return_zero_value_for_not_occuring_class(self):
    """A label with zero occurrences must be assigned a weight of exactly 0."""
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as session:
        weights = session.run(
            tf_calculate_efnet_weights_for_frequency_by_label([1, 0])
        )
    # The last label never occurs.
    assert weights[-1] == 0.0
def test_should_return_same_value_for_classes_with_same_frequencies(self):
    """Labels that occur equally often must receive identical weights."""
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as session:
        weights = session.run(
            tf_calculate_efnet_weights_for_frequency_by_label([1, 1])
        )
    # Exact equality is intended: both entries come from the same computation.
    assert weights[0] == weights[1]
def _create_pos_weights_tensor(
        base_loss, separate_channel_annotation_tensor, pos_weight_values,
        input_uri, debug):
    """Build the positive-class weight tensor for one sample.

    Label frequencies are obtained by summing the annotation tensor over
    axes 0 and 1 (dimensions kept). They are then turned into per-label
    weights by ``tf_calculate_efnet_weights_for_frequency_by_label``.

    Args:
        base_loss: loss variant. For
            ``BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY`` the sample
            weights are additionally multiplied by ``pos_weight_values``.
        separate_channel_annotation_tensor: annotation tensor to count
            labels from. Presumably laid out as (height, width, channel)
            with one channel per label — TODO confirm against callers.
        pos_weight_values: extra per-label weights. Only used for the
            weighted-sample weighted-cross-entropy loss.
        input_uri: identifier of the input. Only included in debug output.
        debug: when truthy, wrap the result in ``tf.Print``. Evaluating it
            then prints the final weights, sample weights, frequencies and
            ``input_uri``.

    Returns:
        The positive-weight tensor.
    """
    # NOTE(review): `keep_dims` is the older spelling of `keepdims`; kept so
    # that older TF1 releases keep working.
    frequency_by_label = tf.reduce_sum(
        separate_channel_annotation_tensor,
        axis=[0, 1],
        keep_dims=True,
        name='frequency_by_channel'
    )
    sample_weights = tf_calculate_efnet_weights_for_frequency_by_label(
        frequency_by_label
    )

    if base_loss == BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY:
        # Combine the data-driven sample weights with the configured weights.
        pos_weight = sample_weights * pos_weight_values
    else:
        pos_weight = sample_weights

    if debug:
        # tf.Print is an identity op that prints its data list when evaluated.
        debug_values = [
            pos_weight, sample_weights, frequency_by_label, input_uri
        ]
        pos_weight = tf.Print(
            pos_weight,
            debug_values,
            'pos weights, sample, frequency, uri: ',
            summarize=1000
        )

    get_logger().debug(
        'pos_weight before batch: %s (frequency_by_label: %s)',
        pos_weight, frequency_by_label
    )
    return pos_weight