Exemplo n.º 1
0
 def test_should_calculate_median_class_weights_for_multiple_image_and_multiple_images(
         self):
     """Median class weights are computed across several images in one file.

     Two single-row PNG images go into one TFRecord file. The first holds
     COLOR_0 (not a requested color), COLOR_1 and COLOR_2. The second holds
     COLOR_1, COLOR_2 and COLOR_3. Weights are requested for COLOR_1..3.
     """
     pixel_rows_per_image = [
         [[COLOR_0, COLOR_1, COLOR_2]],
         [[COLOR_1, COLOR_2, COLOR_3]],
     ]
     with TemporaryDirectory() as work_dir:
         record_path = os.path.join(work_dir, 'data.tfrecord')
         get_logger().debug(
             'writing to test tfrecord_filename: %s', record_path)
         write_examples_to_tfrecord(record_path, [
             dict_to_example({'image': encode_png(pixel_rows)})
             for pixel_rows in pixel_rows_per_image
         ])
         weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
             [record_path], 'image', [COLOR_1, COLOR_2, COLOR_3])
         assert weights == [0.25, 0.25, 0.5]
Exemplo n.º 2
0
 def test_should_filter_by_channel_colors(self):
     """read_examples only yields pages whose colors are in channel_colors.

     TFRecordDataset is patched so read_examples sees four in-memory
     examples (pages 1 to 4). Each page's annotation image uses
     ``some_color(page_no)``. Only the colors of pages 2 and 3 are passed as
     ``channel_colors``.
     """
     all_pages = [1, 2, 3, 4]
     wanted_pages = [2, 3]
     with patch.object(examples_module,
                       'TFRecordDataset') as mock_tf_record_dataset:
         with tf.Graph().as_default():
             serialized_records = []
             for page_no in all_pages:
                 page_props = extend_dict(
                     EXAMPLE_PROPS_1,
                     page_no=page_no,
                     annotation_image=image_with_color(some_color(page_no)))
                 serialized_records.append(
                     dict_to_example(page_props).SerializeToString())
             mock_tf_record_dataset.return_value = list_dataset(
                 serialized_records, tf.string)

             examples = read_examples(
                 DATA_PATH,
                 shuffle=False,
                 num_epochs=1,
                 page_range=(0, 100),
                 channel_colors=[some_color(i) for i in wanted_pages])
             # The reader is expected to open the path as GZIP-compressed.
             mock_tf_record_dataset.assert_called_with(
                 DATA_PATH, compression_type='GZIP')

             with tf.Session() as session:
                 fetched_pages = [
                     example['page_no']
                     for example in fetch_examples(session, examples)
                 ]
                 assert fetched_pages == wanted_pages
Exemplo n.º 3
0
 def test_should_return_zero_for_non_occuring_class(self):
     """A requested color that never appears in the data gets weight 0.0.

     The only image contains a single COLOR_1 pixel. Weights are requested
     for COLOR_1 and COLOR_2.
     """
     with TemporaryDirectory() as work_dir:
         record_path = os.path.join(work_dir, 'data.tfrecord')
         get_logger().debug('writing to test tfrecord_filename: %s',
                            record_path)
         single_pixel_image = encode_png([[COLOR_1]])
         write_examples_to_tfrecord(
             record_path, [dict_to_example({'image': single_pixel_image})])
         weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
             [record_path], 'image', [COLOR_1, COLOR_2])
         assert weights == [1.0, 0.0]
 def expand(self, pcoll):  # pylint: disable=W0221
     """Convert each element of ``pcoll`` to TF examples and write them out.

     For each element, ``self.extract_props`` produces property dicts. Each
     dict is converted with ``dict_to_example``, serialized to bytes, and
     written by ``beam.io.WriteToTFRecord`` under ``self.file_path`` with
     the ``.tfrecord.gz`` suffix.

     Returns the result of applying the WriteToTFRecord transform.
     """
     def to_tf_examples(element):
         # One input element may yield several examples, hence FlatMap.
         return (
             dict_to_example(props)
             for props in self.extract_props(element)
         )

     def serialize(example):
         return example.SerializeToString()

     tf_examples = pcoll | 'ConvertToTfExamples' >> beam.FlatMap(to_tf_examples)
     serialized = tf_examples | 'SerializeToString' >> beam.Map(serialize)
     # The compression type is left at the transform's default; the '.gz'
     # suffix presumably lets that default select gzip. Confirm against the
     # Beam WriteToTFRecord docs.
     return serialized | 'SaveToTfRecords' >> beam.io.WriteToTFRecord(
         self.file_path,
         file_name_suffix='.tfrecord.gz'
     )
Exemplo n.º 5
0
 def test_should_use_color_map_keys_as_channels_by_default(self):
     """Without an explicit ``channels`` argument, all color map keys are used.

     The result's keys must equal the color map's keys.
     """
     color_map = {
         'color1': COLOR_1,
         'color2': COLOR_2,
     }
     with TemporaryDirectory() as work_dir:
         record_path = os.path.join(work_dir, 'data.tfrecord')
         get_logger().debug('writing to test tfrecord_filename: %s',
                            record_path)
         two_color_image = encode_png([[COLOR_1, COLOR_2]])
         write_examples_to_tfrecord(
             record_path, [dict_to_example({'image': two_color_image})])
         weights_by_channel = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
             [record_path], 'image', color_map)
         assert set(weights_by_channel.keys()) == {'color1', 'color2'}
Exemplo n.º 6
0
 def test_should_calculate_median_class_weights_for_single_image_and_single_color(
         self):
     """Only the color map entries listed in ``channels`` get weights.

     The color map has three entries but ``channels`` selects two, and the
     result must contain exactly those two.

     NOTE(review): despite the test name, the image holds two colors and two
     channels are requested. The test really exercises the ``channels``
     filter.
     """
     color_map = {
         'color1': COLOR_1,
         'color2': COLOR_2,
         'color3': COLOR_3,
     }
     selected_channels = ['color1', 'color2']
     with TemporaryDirectory() as work_dir:
         record_path = os.path.join(work_dir, 'data.tfrecord')
         get_logger().debug('writing to test tfrecord_filename: %s',
                            record_path)
         two_color_image = encode_png([[COLOR_1, COLOR_2]])
         write_examples_to_tfrecord(
             record_path, [dict_to_example({'image': two_color_image})])
         weights_by_channel = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
             [record_path],
             'image',
             color_map,
             channels=selected_channels)
         assert weights_by_channel == {'color1': 0.5, 'color2': 0.5}
Exemplo n.º 7
0
 def test_should_include_unknown_class_if_enabled(self):
     """With ``use_unknown_class=True`` the result gains an extra key.

     The extra key is ``unknown_class_label`` ('unknown'). The image also
     contains COLOR_0 and COLOR_3, which are not in the color map.
     """
     known_colors = {
         'color1': COLOR_1,
         'color2': COLOR_2,
     }
     with TemporaryDirectory() as work_dir:
         record_path = os.path.join(work_dir, 'data.tfrecord')
         get_logger().debug('writing to test tfrecord_filename: %s',
                            record_path)
         mixed_image = encode_png([[COLOR_0, COLOR_1, COLOR_2, COLOR_3]])
         write_examples_to_tfrecord(
             record_path, [dict_to_example({'image': mixed_image})])
         weights_by_class = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
             [record_path],
             'image',
             known_colors,
             use_unknown_class=True,
             unknown_class_label='unknown')
         expected_classes = {'color1', 'color2', 'unknown'}
         assert set(weights_by_class.keys()) == expected_classes
Exemplo n.º 8
0
def dict_to_example_and_reverse(props):
    """Round-trip ``props`` through a serialized example and return the result.

    ``props`` is converted with ``dict_to_example`` and serialized to bytes.
    The bytes are decoded again with ``iter_examples_to_dict_list``, and the
    first decoded dict is returned.
    """
    serialized = dict_to_example(props).SerializeToString()
    decoded_dicts = list(iter_examples_to_dict_list([serialized]))
    return decoded_dicts[0]
Exemplo n.º 9
0
def example_dataset(map_keys_tracker, examples):
    """Build a dataset of parsed examples from a list of property dicts.

    Each dict in ``examples`` is converted with ``dict_to_example`` and
    serialized. The resulting strings become a ``tf.string`` dataset, which
    is then mapped through ``parse_example`` wrapped by ``map_keys_tracker``.
    """
    serialized_examples = [
        dict_to_example(props).SerializeToString() for props in examples
    ]
    raw_dataset = list_dataset(serialized_examples, tf.string)
    return raw_dataset.map(map_keys_tracker.wrap(parse_example))
Exemplo n.º 10
0
import sciencebeam_gym.trainer.data.examples as examples_module
from sciencebeam_gym.trainer.data.examples import (read_examples, tf_data)

# Glob pattern of TFRecord files passed to read_examples in these tests.
DATA_PATH = '.temp/data/*.tfrecord'

# NOTE(review): not referenced in the visible tests; confirm it is still used.
IMAGE_SHAPE = (5, 5)

# Minimal example properties. The image fields hold placeholder bytes, not
# real encoded images.
EXAMPLE_PROPS_1 = {
    'input_uri': 'input.png',
    'input_image': b'input image',
    'annotation_uri': 'annotation.png',
    'annotation_image': b'annotation image',
}

# EXAMPLE_PROPS_1 encoded as a serialized example (bytes).
RECORD_1 = dict_to_example(EXAMPLE_PROPS_1).SerializeToString()


def get_logger():
    """Return the logger named after this module."""
    logger_name = __name__
    return logging.getLogger(logger_name)


def setup_module():
    """Configure root logging at DEBUG level.

    Presumably invoked by pytest as the xunit-style module setup hook.
    ``logging.basicConfig`` has no effect if the root logger already has
    handlers.
    """
    debug_level = 'DEBUG'
    logging.basicConfig(level=debug_level)


def list_dataset(data, dtype):
    """Return a dataset that yields the elements of ``data`` one by one.

    ``data`` is first wrapped in a constant tensor of ``dtype``. The dataset
    is then built with ``from_tensor_slices``, which slices that tensor
    along its first dimension.
    """
    return tf_data.Dataset.from_tensor_slices(tf.constant(data, dtype=dtype))