def test_should_calculate_median_class_weights_for_multiple_image_and_multiple_images(
        self):
    """Check the median class weights computed over two images and three colors.

    Two examples are written to a temporary tfrecord file, each holding one
    PNG built from a single list of three colors. The weights requested for
    COLOR_1, COLOR_2 and COLOR_3 are expected to be [0.25, 0.25, 0.5].
    """
    # One entry per example image; each entry is the single pixel row of that image.
    pixel_rows = [
        [COLOR_0, COLOR_1, COLOR_2],
        [COLOR_1, COLOR_2, COLOR_3],
    ]
    expected_weights = [0.25, 0.25, 0.5]
    with TemporaryDirectory() as temp_dir:
        tfrecord_path = os.path.join(temp_dir, 'data.tfrecord')
        get_logger().debug(
            'writing to test tfrecord_filename: %s', tfrecord_path)
        write_examples_to_tfrecord(tfrecord_path, [
            dict_to_example({'image': encode_png([pixel_row])})
            for pixel_row in pixel_rows
        ])
        class_weights = (
            calculate_median_class_weights_for_tfrecord_paths_and_colors(
                [tfrecord_path], 'image', [COLOR_1, COLOR_2, COLOR_3]))
        assert class_weights == expected_weights
def test_should_filter_by_channel_colors(self):
    """Check that read_examples only yields pages matching channel_colors.

    TFRecordDataset is patched to return an in-memory dataset of four
    serialized examples (page_no 1..4), each with an annotation image built
    from some_color(page_no). Reading with the colors for pages 2 and 3 is
    expected to yield exactly those two pages.
    """
    with patch.object(examples_module, 'TFRecordDataset') as dataset_mock, \
            tf.Graph().as_default():
        serialized_examples = [
            dict_to_example(
                extend_dict(
                    EXAMPLE_PROPS_1,
                    page_no=page_no,
                    annotation_image=image_with_color(some_color(page_no)))
            ).SerializeToString()
            for page_no in [1, 2, 3, 4]
        ]
        dataset_mock.return_value = list_dataset(
            serialized_examples, tf.string)

        selected_colors = [some_color(i) for i in [2, 3]]
        examples = read_examples(
            DATA_PATH,
            shuffle=False,
            num_epochs=1,
            page_range=(0, 100),
            channel_colors=selected_colors)
        dataset_mock.assert_called_with(DATA_PATH, compression_type='GZIP')

        with tf.Session() as session:
            fetched_page_nos = [
                example['page_no']
                for example in fetch_examples(session, examples)
            ]
            assert fetched_page_nos == [2, 3]
def test_should_return_zero_for_non_occuring_class(self):
    """Check that a color absent from the image gets a class weight of 0.0.

    A single example whose image contains only COLOR_1 is written; the
    weights requested for COLOR_1 and COLOR_2 are expected to be [1.0, 0.0].
    """
    with TemporaryDirectory() as temp_dir:
        tfrecord_path = os.path.join(temp_dir, 'data.tfrecord')
        get_logger().debug(
            'writing to test tfrecord_filename: %s', tfrecord_path)
        single_color_example = dict_to_example(
            {'image': encode_png([[COLOR_1]])})
        write_examples_to_tfrecord(tfrecord_path, [single_color_example])
        class_weights = (
            calculate_median_class_weights_for_tfrecord_paths_and_colors(
                [tfrecord_path], 'image', [COLOR_1, COLOR_2]))
        assert class_weights == [1.0, 0.0]
def expand(self, pcoll):  # pylint: disable=W0221
    """Apply the three labelled steps that write *pcoll* out as TFRecords.

    Each input value is expanded through ``self.extract_props`` into zero or
    more property dicts, each dict is converted with ``dict_to_example`` and
    serialized, and the serialized records are written under
    ``self.file_path`` with the ``.tfrecord.gz`` file name suffix.
    """
    tf_examples = pcoll | 'ConvertToTfExamples' >> beam.FlatMap(
        lambda v: (
            dict_to_example(props)
            for props in self.extract_props(v)
        )
    )
    serialized_examples = tf_examples | 'SerializeToString' >> beam.Map(
        lambda x: x.SerializeToString()
    )
    return serialized_examples | 'SaveToTfRecords' >> beam.io.WriteToTFRecord(
        self.file_path,
        file_name_suffix='.tfrecord.gz'
    )
def test_should_use_color_map_keys_as_channels_by_default(self):
    """Check that, without explicit channels, the result is keyed by the color map keys.

    A single example containing COLOR_1 and COLOR_2 is written and the
    weights map is computed from a two-entry color map; its keys are
    expected to be exactly 'color1' and 'color2'.
    """
    color_map = {
        'color1': COLOR_1,
        'color2': COLOR_2
    }
    with TemporaryDirectory() as temp_dir:
        tfrecord_path = os.path.join(temp_dir, 'data.tfrecord')
        get_logger().debug(
            'writing to test tfrecord_filename: %s', tfrecord_path)
        example = dict_to_example(
            {'image': encode_png([[COLOR_1, COLOR_2]])})
        write_examples_to_tfrecord(tfrecord_path, [example])
        class_weights_map = (
            calculate_median_class_weights_for_tfrecord_paths_and_color_map(
                [tfrecord_path], 'image', color_map))
        actual_keys = set(class_weights_map.keys())
        assert actual_keys == {'color1', 'color2'}
def test_should_calculate_median_class_weights_for_single_image_and_single_color(
        self):
    """Check the weights map when channels restrict a larger color map.

    A single example containing COLOR_1 and COLOR_2 is written. The color
    map has three entries but only 'color1' and 'color2' are passed as
    channels; the expected result is a weight of 0.5 for each of those two.
    """
    color_map = {
        'color1': COLOR_1,
        'color2': COLOR_2,
        'color3': COLOR_3
    }
    selected_channels = ['color1', 'color2']
    with TemporaryDirectory() as temp_dir:
        tfrecord_path = os.path.join(temp_dir, 'data.tfrecord')
        get_logger().debug(
            'writing to test tfrecord_filename: %s', tfrecord_path)
        example = dict_to_example(
            {'image': encode_png([[COLOR_1, COLOR_2]])})
        write_examples_to_tfrecord(tfrecord_path, [example])
        class_weights_map = (
            calculate_median_class_weights_for_tfrecord_paths_and_color_map(
                [tfrecord_path], 'image', color_map,
                channels=selected_channels))
        assert class_weights_map == {'color1': 0.5, 'color2': 0.5}
def test_should_include_unknown_class_if_enabled(self):
    """Check that enabling the unknown class adds its label to the result keys.

    A single example containing COLOR_0..COLOR_3 is written. The color map
    only names COLOR_1 and COLOR_2; with ``use_unknown_class=True`` and
    ``unknown_class_label='unknown'`` the result keys are expected to be
    'color1', 'color2' and 'unknown'.
    """
    color_map = {
        'color1': COLOR_1,
        'color2': COLOR_2
    }
    with TemporaryDirectory() as temp_dir:
        tfrecord_path = os.path.join(temp_dir, 'data.tfrecord')
        get_logger().debug(
            'writing to test tfrecord_filename: %s', tfrecord_path)
        example = dict_to_example({
            'image': encode_png([[COLOR_0, COLOR_1, COLOR_2, COLOR_3]])
        })
        write_examples_to_tfrecord(tfrecord_path, [example])
        class_weights_map = (
            calculate_median_class_weights_for_tfrecord_paths_and_color_map(
                [tfrecord_path], 'image', color_map,
                use_unknown_class=True,
                unknown_class_label='unknown'))
        actual_keys = set(class_weights_map.keys())
        assert actual_keys == {'color1', 'color2', 'unknown'}
def dict_to_example_and_reverse(props):
    """Round-trip *props* through a serialized example and back.

    The props are converted with ``dict_to_example``, serialized to a
    string, passed as a one-element list to ``iter_examples_to_dict_list``,
    and the first resulting item is returned.
    """
    serialized_example = dict_to_example(props).SerializeToString()
    # Materialise the full result before indexing, as the helper is iterated lazily.
    round_tripped = list(iter_examples_to_dict_list([serialized_example]))
    return round_tripped[0]
def example_dataset(map_keys_tracker, examples):
    """Build a dataset of parsed examples from a list of property dicts.

    Each item of *examples* is converted with ``dict_to_example`` and
    serialized; the resulting strings form a ``tf.string`` dataset which is
    then mapped through ``parse_example`` wrapped by
    ``map_keys_tracker.wrap``.
    """
    serialized_examples = [
        dict_to_example(props).SerializeToString()
        for props in examples
    ]
    raw_dataset = list_dataset(serialized_examples, tf.string)
    return raw_dataset.map(map_keys_tracker.wrap(parse_example))
import sciencebeam_gym.trainer.data.examples as examples_module
from sciencebeam_gym.trainer.data.examples import (
    read_examples,
    tf_data
)

# Glob pattern passed to read_examples as the data location in tests.
DATA_PATH = '.temp/data/*.tfrecord'

IMAGE_SHAPE = (5, 5)

# Baseline example properties; tests extend these with per-example fields.
EXAMPLE_PROPS_1 = {
    'input_uri': 'input.png',
    'input_image': b'input image',
    'annotation_uri': 'annotation.png',
    'annotation_image': b'annotation image'
}

# EXAMPLE_PROPS_1 converted to an example and serialized.
RECORD_1 = dict_to_example(EXAMPLE_PROPS_1).SerializeToString()


def get_logger():
    """Return the logger named after this module."""
    module_logger = logging.getLogger(__name__)
    return module_logger


def setup_module():
    """Configure logging at DEBUG level via ``logging.basicConfig``."""
    logging.basicConfig(level='DEBUG')


def list_dataset(data, dtype):
    """Return a dataset that slices a constant tensor built from *data*.

    *data* is turned into a ``tf.constant`` of the given *dtype* and passed
    to ``tf_data.Dataset.from_tensor_slices``.
    """
    return tf_data.Dataset.from_tensor_slices(
        tf.constant(data, dtype=dtype))