def test_get_configs_from_multiple_files(self):
    """Tests that proto configs can be read from multiple files.

    One text-format config file is written per proto type into a temporary
    directory. The files are read back via
    `config_util.get_configs_from_multiple_files`, and each returned proto must
    equal the one that was written.
    """
    temp_dir = self.get_temp_dir()

    # Write model config file.
    model_config_path = os.path.join(temp_dir, "model.config")
    model = model_pb2.DetectionModel()
    model.faster_rcnn.num_classes = 10
    _write_config(model, model_config_path)

    # Write train config file.
    train_config_path = os.path.join(temp_dir, "train.config")
    train_config = train_pb2.TrainConfig()
    train_config.batch_size = 32
    _write_config(train_config, train_config_path)

    # Write train input config file.
    train_input_config_path = os.path.join(temp_dir, "train_input.config")
    train_input_config = input_reader_pb2.InputReader()
    train_input_config.label_map_path = "path/to/label_map"
    _write_config(train_input_config, train_input_config_path)

    # Write eval config file.
    eval_config_path = os.path.join(temp_dir, "eval.config")
    eval_config = eval_pb2.EvalConfig()
    eval_config.num_examples = 20
    _write_config(eval_config, eval_config_path)

    # Write eval input config file.
    eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
    eval_input_config = input_reader_pb2.InputReader()
    eval_input_config.label_map_path = "path/to/another/label_map"
    _write_config(eval_input_config, eval_input_config_path)

    configs = config_util.get_configs_from_multiple_files(
        model_config_path=model_config_path,
        train_config_path=train_config_path,
        train_input_config_path=train_input_config_path,
        eval_config_path=eval_config_path,
        eval_input_config_path=eval_input_config_path)

    self.assertProtoEquals(model, configs["model"])
    self.assertProtoEquals(train_config, configs["train_config"])
    self.assertProtoEquals(train_input_config, configs["train_input_config"])
    self.assertProtoEquals(eval_config, configs["eval_config"])
    # Eval input configs come back as an indexable sequence under the plural
    # key; only one file was written, so compare the first entry.
    self.assertProtoEquals(eval_input_config, configs["eval_input_configs"][0])
def test_raises_error_with_no_input_paths(self):
    """input_reader_builder.build rejects a config that names no input paths."""
    reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(""" shuffle: false num_readers: 1 load_instance_masks: true """, reader_proto)
    with self.assertRaises(ValueError):
        input_reader_builder.build(reader_proto)
def test_reduce_num_reader(self):
    """Unshuffled read with num_readers=10 over every file matching the template.

    NOTE(review): the test name suggests the builder caps num_readers at the
    number of matched files -- confirm against the builder implementation.
    """
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_readers = 10
    reader_config.shuffle = False
    next_batch = self._get_dataset_next(
        [self._path_template % '*'], reader_config, batch_size=20)

    # The batch of 20 is this 10-value sequence repeated twice.
    half_batch = [1, 10, 2, 20, 3, 30, 4, 40, 5, 50]
    with self.test_session() as sess:
        self.assertAllEqual(sess.run(next_batch), [half_batch * 2])
def test_disable_shuffle_(self):
    """With shuffle disabled, the batch of 10 is five 0s followed by five 1s."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_readers = 1
    reader_config.shuffle = False
    next_batch = self._get_dataset_next(
        [self._shuffle_path_template % '*'], reader_config, batch_size=10)

    unshuffled = [0] * 5 + [1] * 5
    with self.test_session() as sess:
        self.assertAllEqual(sess.run(next_batch), [unshuffled])
def test_read_dataset_single_epoch(self):
    """A single-epoch dataset yields one partial batch and is then exhausted."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_epochs = 1
    reader_config.num_readers = 1
    reader_config.shuffle = False
    next_batch = self._get_dataset_next(
        [self._path_template % '0'], reader_config, batch_size=30)

    with self.test_session() as sess:
        # The first run returns whatever is available (fewer items than
        # batch_size). A second run has nothing left and must fail.
        self.assertAllEqual(sess.run(next_batch), [[1, 10]])
        self.assertRaises(tf.errors.OutOfRangeError, sess.run, next_batch)
def test_enable_shuffle(self):
    """With shuffle enabled, the batch differs somewhere from the unshuffled order."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_readers = 1
    reader_config.shuffle = True
    tf.set_random_seed(1)  # Fix the graph-level seed.
    next_batch = self._get_dataset_next(
        [self._shuffle_path_template % '*'], reader_config, batch_size=10)

    unshuffled = [0] * 5 + [1] * 5
    with self.test_session() as sess:
        shuffled = sess.run(next_batch)
        self.assertTrue(np.any(np.not_equal(shuffled, unshuffled)))
def get_input_reader_config(tf_record_path):
    """Build an unshuffled, single-reader InputReader proto for one TFRecord path.

    Args:
      tf_record_path: Value substituted into `tf_record_input_reader.input_path`.

    Returns:
      An `input_reader_pb2.InputReader` populated from the text template.
    """
    config_text = """ shuffle: false num_readers: 1 tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    reader_config = input_reader_pb2.InputReader()
    text_format.Merge(config_text, reader_config)
    return reader_config
def test_disable_shuffle_(self):
    """With shuffle disabled, the batch of 10 is five 0s followed by five 1s."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_readers = 1
    reader_config.shuffle = False

    def graph_fn():
        return self._get_dataset_next(
            [self._shuffle_path_template % '*'], reader_config, batch_size=10)

    # `execute` unwraps a single returned output, so the result is compared
    # directly against a flat list.
    batch = self.execute(graph_fn, [])
    self.assertAllEqual(batch, [0] * 5 + [1] * 5)
def test_raises_error_without_input_paths(self):
    """seq_dataset_builder.build raises ValueError when no input path is configured."""
    proto_text = """ shuffle: false num_readers: 1 load_instance_masks: true """
    reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(proto_text, reader_proto)
    model_configs = self._get_model_configs_from_proto()
    with self.assertRaises(ValueError):
        seq_dataset_builder.build(
            reader_proto,
            model_configs['model'],
            model_configs['lstm_model'],
            unroll_length=1)
def test_video_input_reader(self, video_input_type):
    """Builds a sequence dataset for `video_input_type` and checks entry shapes.

    Only the `image0` and `groundtruth_boxes0` entries of the training dict
    are inspected.
    """
    reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(self._get_input_proto(video_input_type), reader_proto)
    model_configs = self._get_model_configs_from_proto()

    tensor_dict = seq_dataset_builder.build(
        reader_proto,
        model_configs['model'],
        model_configs['lstm_model'],
        unroll_length=1)
    training_dict = self._create_training_dict(tensor_dict)

    self.assertEqual((1, 32, 32, 3), training_dict['image0'].shape)
    self.assertEqual(4, training_dict['groundtruth_boxes0'].shape[1])
def get_dataset(tfrecord_path, label_map='label_map.pbtxt'):
    """Build a TensorFlow dataset that reads a single TFRecord file.

    Args:
      tfrecord_path [str]: path to a TFRecord file.
      label_map [str]: path to the label map file.

    Returns:
      dataset [tf.Dataset]: the dataset produced by `build_dataset`.
    """
    reader_config = input_reader_pb2.InputReader()
    reader_config.label_map_path = label_map
    # Replace any existing input paths with exactly this one file.
    reader_config.tf_record_input_reader.input_path[:] = [tfrecord_path]
    return build_dataset(reader_config)
def test_enable_shuffle(self):
    """With shuffle enabled, the batch differs somewhere from the unshuffled order."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_readers = 1
    reader_config.shuffle = True
    tf.set_random_seed(1)  # Fix the graph-level seed.

    def graph_fn():
        return self._get_dataset_next(
            [self._shuffle_path_template % '*'], reader_config, batch_size=10)

    unshuffled = [0] * 5 + [1] * 5
    batch = self.execute(graph_fn, [])
    self.assertTrue(np.any(np.not_equal(batch, unshuffled)))
def get_configs_from_multiple_files():
    """Read the train, model and input configs named by command-line flags.

    Each flag (`train_config_path`, `model_config_path`, `input_config_path`)
    points at a text-format protobuf file. The file is merged into a freshly
    constructed message of the matching type.

    Returns:
      Tuple of (model_config, train_config, input_config).
    """

    def _merge_file_into(message, path):
        # Parse the text-format file at `path` into `message` and return it.
        with tf.gfile.GFile(path, 'r') as handle:
            text_format.Merge(handle.read(), message)
        return message

    train_config = _merge_file_into(train_pb2.TrainConfig(), FLAGS.train_config_path)
    model_config = _merge_file_into(model_pb2.DetectionModel(), FLAGS.model_config_path)
    input_config = _merge_file_into(input_reader_pb2.InputReader(), FLAGS.input_config_path)
    return model_config, train_config, input_config
def test_build_tf_record_input_reader_and_load_instance_masks(self):
    """Decoding with load_instance_masks yields a (1, 4, 5) instance-mask array."""
    reader_proto = input_reader_pb2.InputReader()
    text_format.Parse(""" load_instance_masks: true tf_record_input_reader {} """, reader_proto)
    decoder = decoder_builder.build(reader_proto)
    decoded = decoder.decode(self._make_serialized_tf_example())

    with tf.train.MonitoredSession() as sess:
        output_dict = sess.run(decoded)

    masks = output_dict[fields.InputDataFields.groundtruth_instance_masks]
    self.assertAllEqual((1, 4, 5), masks.shape)
def test_build_tf_record_input_reader_sequence_example(self):
    """Reads a TF_SEQUENCE_EXAMPLE record and checks the decoded groundtruth."""
    tf_record_path = self.create_tf_record_sequence_example()
    proto_text = """ shuffle: false num_readers: 1 input_type: TF_SEQUENCE_EXAMPLE tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    reader_proto = input_reader_pb2.InputReader()
    reader_proto.label_map_path = _get_labelmap_path()
    text_format.Merge(proto_text, reader_proto)

    tensor_dict = input_reader_builder.build(reader_proto)
    with tf.train.MonitoredSession() as sess:
        output_dict = sess.run(tensor_dict)

    zero_box = [0.0, 0.0, 0.0, 0.0]
    unit_box = [0.0, 0.0, 1.0, 1.0]
    small_box = [0.1, 0.1, 0.2, 0.2]
    expected_classes = [[-1, -1], [1, -1], [1, 2], [-1, -1]]
    expected_boxes = [
        [zero_box, zero_box],
        [unit_box, zero_box],
        [unit_box, small_box],
        [zero_box, zero_box],
    ]
    expected_num_boxes = [0, 1, 2, 0]

    image = output_dict[fields.InputDataFields.image]
    classes = output_dict[fields.InputDataFields.groundtruth_classes]
    boxes = output_dict[fields.InputDataFields.groundtruth_boxes]
    num_boxes = output_dict[fields.InputDataFields.num_groundtruth_boxes]

    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks, output_dict)
    # Sequence example images are still encoded, so the image output is 1-D.
    self.assertEqual((4, ), image.shape)
    self.assertAllEqual(expected_classes, classes)
    self.assertEqual((4, 2, 4), boxes.shape)
    self.assertAllClose(expected_boxes, boxes)
    self.assertAllClose(expected_num_boxes, num_boxes)
def test_build_tf_record_input_reader_with_batch_size_two(self):
    """Batches two examples with one-hot classes, padded to max_num_boxes=2."""
    tf_record_path = self.create_tf_record()
    proto_text = """ shuffle: false num_readers: 1 tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(proto_text, reader_proto)

    def one_hot_class_encoding_fn(tensor_dict):
        # Shift labels down by one before one-hot encoding (presumably the
        # stored class ids are 1-based -- confirm against the record writer).
        class_ids = tensor_dict[fields.InputDataFields.groundtruth_classes]
        tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
            class_ids - 1, depth=3)
        return tensor_dict

    dataset = dataset_builder.build(
        reader_proto,
        transform_input_data_fn=one_hot_class_encoding_fn,
        batch_size=2,
        max_num_boxes=2,
        num_classes=3,
        spatial_image_shape=[4, 5])
    next_element = dataset_util.make_initializable_iterator(dataset).get_next()

    supervisor = tf.train.Supervisor(logdir=self.get_temp_dir())
    with supervisor.prepare_or_wait_for_session() as sess:
        supervisor.start_queue_runners(sess)
        output_dict = sess.run(next_element)

    image = output_dict[fields.InputDataFields.image]
    classes = output_dict[fields.InputDataFields.groundtruth_classes]
    boxes = output_dict[fields.InputDataFields.groundtruth_boxes]
    # Each example has one real box followed by one all-zero padding row.
    boxes_per_example = [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]

    self.assertAllEqual([2, 4, 5, 3], image.shape)
    self.assertAllEqual([2, 2, 3], classes.shape)
    self.assertAllEqual([2, 2, 4], boxes.shape)
    self.assertAllEqual([boxes_per_example, boxes_per_example], boxes)
def test_reduce_num_reader(self):
    """A 10-reader unshuffled read yields the expected multiset of 20 values."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_readers = 10
    reader_config.shuffle = False

    def graph_fn():
        return self._get_dataset_next(
            [self._path_template % '*'], reader_config, batch_size=20)

    # `execute` unwraps a single returned output. assertCountEqual ignores
    # element order, so only the contents of the batch are checked.
    batch = self.execute(graph_fn, [])
    self.assertCountEqual(batch, [1, 10, 2, 20, 3, 30, 4, 40, 5, 50] * 2)
def test_with_input_context(self):
    """Test that a subset is read with input context given.

    With the given InputContext, each of the first 8 batches has a leading
    dimension of 2. Skipping 8 batches exhausts the pipeline.
    """
    tf_record_path = self.create_tf_record(num_examples_per_shard=16,
                                           num_shards=2)
    proto_text = """ shuffle: false num_readers: 1 num_epochs: 1 tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(proto_text, reader_proto)
    input_context = tf.distribute.InputContext(
        num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=4)

    def make_graph_fn(num_to_skip):
        # Bind the skip count explicitly so each graph_fn sees its own value.
        def graph_fn():
            dataset = dataset_builder.build(reader_proto, batch_size=8,
                                            input_context=input_context)
            dataset = dataset.skip(num_to_skip)
            return get_iterator_next_for_testing(dataset, self.is_tf2())
        return graph_fn

    for skip_count in range(8):
        batch = self.execute(make_graph_fn(skip_count), [])
        self.assertEqual(batch['image'].shape, (2, 4, 5, 3))

    self.assertRaises(tf.errors.OutOfRangeError, self.execute,
                      compute_fn=make_graph_fn(8), inputs=[])
def test_build_tf_record_input_reader_sequence_example(self):
    """Decodes a serialized SequenceExample and checks the groundtruth fields."""
    label_map_path = _get_labelmap_path()
    proto_text = """ input_type: TF_SEQUENCE_EXAMPLE tf_record_input_reader {} """
    reader_proto = input_reader_pb2.InputReader()
    reader_proto.label_map_path = label_map_path
    text_format.Parse(proto_text, reader_proto)

    decoder = decoder_builder.build(reader_proto)
    decoded = decoder.decode(self._make_serialized_tf_sequence_example())
    with tf.train.MonitoredSession() as sess:
        output_dict = sess.run(decoded)

    zero_box = [0.0, 0.0, 0.0, 0.0]
    unit_box = [0.0, 0.0, 1.0, 1.0]
    small_box = [0.1, 0.1, 0.2, 0.2]
    expected_classes = [[-1, -1], [1, -1], [1, 2], [-1, -1]]
    expected_boxes = [
        [zero_box, zero_box],
        [unit_box, zero_box],
        [unit_box, small_box],
        [zero_box, zero_box],
    ]
    expected_num_boxes = [0, 1, 2, 0]

    image = output_dict[fields.InputDataFields.image]
    classes = output_dict[fields.InputDataFields.groundtruth_classes]
    boxes = output_dict[fields.InputDataFields.groundtruth_boxes]
    num_boxes = output_dict[fields.InputDataFields.num_groundtruth_boxes]

    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks, output_dict)
    # Sequence example images are encoded, so the image output is 1-D.
    self.assertEqual((4, ), image.shape)
    self.assertAllEqual(expected_classes, classes)
    self.assertEqual((4, 2, 4), boxes.shape)
    self.assertAllClose(expected_boxes, boxes)
    self.assertAllClose(expected_num_boxes, num_boxes)
def test_build_tf_record_input_reader_and_load_instance_masks(self):
    """Decoding with load_instance_masks yields a (1, 4, 5) instance-mask array."""
    reader_proto = input_reader_pb2.InputReader()
    text_format.Parse(""" load_instance_masks: true tf_record_input_reader {} """, reader_proto)
    decoder = decoder_builder.build(reader_proto)
    serialized_example = self._make_serialized_tf_example()

    def graph_fn():
        decoded = decoder.decode(serialized_example)
        return decoded[fields.InputDataFields.groundtruth_instance_masks]

    masks = self.execute_cpu(graph_fn, [])
    self.assertAllEqual((1, 4, 5), masks.shape)
def test_read_dataset_sample_from_datasets_weights_unbalanced(self):
    """Ensure that the files' values are sampled per the unbalanced 0.1/0.9 weights.

    The first file should contribute about 10% of the batch, split between its
    values 1 and 10. The second should contribute about 90%, split between its
    values 2 and 20.
    """
    config = input_reader_pb2.InputReader()
    config.num_readers = 2
    config.shuffle = False
    config.sample_from_datasets_weights.extend([0.1, 0.9])

    def graph_fn():
        return self._get_dataset_next(
            [self._path_template % '0', self._path_template % '1'],
            config,
            batch_size=1000)

    data = list(self.execute(graph_fn, []))
    self.assertEqual(len(data), 1000)
    # NOTE(review): the third argument is presumably the expected fraction of
    # `data` equal to the given value -- confirm against _assert_item_count.
    self._assert_item_count(data, 1, 0.05)
    self._assert_item_count(data, 10, 0.05)
    self._assert_item_count(data, 2, 0.45)
    self._assert_item_count(data, 20, 0.45)
def test_build_tf_record_input_reader_with_batch_size_two(self):
    """Batches two single-box examples with one-hot encoded classes."""
    tf_record_path = self.create_tf_record()
    input_reader_text_proto = """ shuffle: false num_readers: 1 tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)

    def one_hot_class_encoding_fn(tensor_dict):
        # Shift labels down by one before one-hot encoding (presumably the
        # stored class ids are 1-based -- confirm against the record writer).
        tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
            tensor_dict[fields.InputDataFields.groundtruth_classes] - 1,
            depth=3)
        return tensor_dict

    dataset = dataset_builder.build(
        input_reader_proto,
        transform_input_data_fn=one_hot_class_encoding_fn,
        batch_size=2)
    # make_initializable_iterator(dataset, shared_name=None): pass only the
    # dataset. The optional second positional argument is `shared_name`.
    tensor_dict = tf.compat.v1.data.make_initializable_iterator(
        dataset).get_next()

    with tf.compat.v1.train.MonitoredSession() as sess:
        output_dict = sess.run(tensor_dict)

    self.assertAllEqual([2, 4, 5, 3],
                        output_dict[fields.InputDataFields.image].shape)
    self.assertAllEqual(
        [2, 1, 3],
        output_dict[fields.InputDataFields.groundtruth_classes].shape)
    self.assertAllEqual(
        [2, 1, 4],
        output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [[[0.0, 0.0, 1.0, 1.0]], [[0.0, 0.0, 1.0, 1.0]]],
        output_dict[fields.InputDataFields.groundtruth_boxes])
def test_build_tf_record_input_reader(self):
    """Decodes a serialized tf.Example and checks image, class and box outputs."""
    reader_proto = input_reader_pb2.InputReader()
    text_format.Parse('tf_record_input_reader {}', reader_proto)
    decoder = decoder_builder.build(reader_proto)
    serialized_example = self._make_serialized_tf_example()

    def graph_fn():
        decoded = decoder.decode(serialized_example)
        return (decoded[fields.InputDataFields.image],
                decoded[fields.InputDataFields.groundtruth_classes],
                decoded[fields.InputDataFields.groundtruth_boxes])

    image, classes, boxes = self.execute_cpu(graph_fn, [])
    self.assertEqual((4, 5, 3), image.shape)
    self.assertAllEqual([2], classes)
    self.assertEqual((1, 4), boxes.shape)
    self.assertAllEqual([0.0, 0.0, 1.0, 1.0], boxes[0])
def test_build_tf_record_input_reader(self):
    """Reads one record with batch_size=1 and checks image and groundtruth fields."""
    tf_record_path = self.create_tf_record()
    input_reader_text_proto = """ shuffle: false num_readers: 1 tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_util.make_initializable_iterator(
        dataset=dataset_builder.build(
            input_reader_config=input_reader_proto, batch_size=1)).get_next()

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
        sv.start_queue_runners(sess)
        output_dict = sess.run(tensor_dict)

    # Use assertEqual rather than the deprecated assertEquals alias, which was
    # removed in Python 3.12. Use assertNotIn for a clearer failure message.
    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                     output_dict)
    self.assertEqual((1, 4, 5, 3),
                     output_dict[fields.InputDataFields.image].shape)
    self.assertAllEqual(
        [[2]], output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0][0])
def test_build_tf_record_input_reader_sequence_example(self):
    """Decodes a serialized SequenceExample and checks the groundtruth fields."""
    label_map_path = _get_labelmap_path()
    proto_text = """ input_type: TF_SEQUENCE_EXAMPLE tf_record_input_reader {} """
    reader_proto = input_reader_pb2.InputReader()
    reader_proto.label_map_path = label_map_path
    text_format.Parse(proto_text, reader_proto)
    serialized_sequence = self._make_serialized_tf_sequence_example()

    def graph_fn():
        # The decoder is built inside graph_fn, as in the original test.
        decoder = decoder_builder.build(reader_proto)
        decoded = decoder.decode(serialized_sequence)
        return (decoded[fields.InputDataFields.image],
                decoded[fields.InputDataFields.groundtruth_classes],
                decoded[fields.InputDataFields.groundtruth_boxes],
                decoded[fields.InputDataFields.num_groundtruth_boxes])

    image, classes, boxes, num_boxes = self.execute_cpu(graph_fn, [])

    zero_box = [0.0, 0.0, 0.0, 0.0]
    unit_box = [0.0, 0.0, 1.0, 1.0]
    small_box = [0.1, 0.1, 0.2, 0.2]
    expected_classes = [[-1, -1], [1, -1], [1, 2], [-1, -1]]
    expected_boxes = [
        [zero_box, zero_box],
        [unit_box, zero_box],
        [unit_box, small_box],
        [zero_box, zero_box],
    ]
    expected_num_boxes = [0, 1, 2, 0]

    # Sequence example images are encoded, so the image output is 1-D.
    self.assertEqual((4, ), image.shape)
    self.assertAllEqual(expected_classes, classes)
    self.assertAllClose(expected_boxes, boxes)
    self.assertAllClose(expected_num_boxes, num_boxes)
def test_build_tf_record_input_reader_and_load_instance_masks(self):
    """A batch_size=1 read with load_instance_masks gives (1, 1, 4, 5) masks."""
    tf_record_path = self.create_tf_record()
    proto_text = """ shuffle: false num_readers: 1 load_instance_masks: true tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(proto_text, reader_proto)

    dataset = dataset_builder.build(reader_proto, batch_size=1)
    next_element = dataset_builder.make_initializable_iterator(dataset).get_next()
    with tf.train.MonitoredSession() as sess:
        output_dict = sess.run(next_element)

    masks = output_dict[fields.InputDataFields.groundtruth_instance_masks]
    self.assertAllEqual((1, 1, 4, 5), masks.shape)
def test_build_with_data_augmentation(self):
    """Builds a video sequence dataset with one augmentation step and checks shapes."""
    reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(
        self._get_input_proto('tf_record_video_input_reader'), reader_proto)
    model_configs = self._get_model_configs_from_proto()
    augmentation = preprocessor_builder.build(
        self._get_data_augmentation_preprocessor_proto())

    tensor_dict = seq_dataset_builder.build(
        reader_proto,
        model_configs['model'],
        model_configs['lstm_model'],
        unroll_length=1,
        data_augmentation_options=[augmentation])
    training_dict = self._create_training_dict(tensor_dict)

    self.assertEqual((1, 32, 32, 3), training_dict['image0'].shape)
    self.assertEqual(4, training_dict['groundtruth_boxes0'].shape[1])
def test_read_dataset_sample_from_datasets_weights_non_normalized(self):
    """Ensure that the values are equally-weighted when not normalized.

    Weights [1, 1] do not sum to 1 but still describe a 50/50 split, so each of
    the four values is expected to make up about a quarter of the batch.
    """
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_readers = 2
    reader_config.shuffle = False
    reader_config.sample_from_datasets_weights.extend([1, 1])

    def graph_fn():
        return self._get_dataset_next(
            [self._path_template % '0', self._path_template % '1'],
            reader_config,
            batch_size=1000)

    batch = list(self.execute(graph_fn, []))
    self.assertEqual(len(batch), 1000)
    for value in (1, 10, 2, 20):
        self._assert_item_count(batch, value, 0.25)
def test_sample_one_of_n_shards(self):
    """sample_1_of_n_examples=2 over 4 records yields source ids '0' then '2'."""
    tf_record_path = self.create_tf_record(num_examples=4)
    input_reader_text_proto = """ shuffle: false num_readers: 1 sample_1_of_n_examples: 2 tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    with tf.train.MonitoredSession() as sess:
        # Both checks use assertAllEqual. The deprecated assertEquals alias was
        # removed in Python 3.12, and a plain equality check on an ndarray
        # depends on array truthiness.
        # NOTE(review): under Python 3 source_id may come back as bytes --
        # confirm the expected literals match the decoded dtype.
        output_dict = sess.run(tensor_dict)
        self.assertAllEqual(['0'],
                            output_dict[fields.InputDataFields.source_id])
        output_dict = sess.run(tensor_dict)
        self.assertAllEqual(['2'],
                            output_dict[fields.InputDataFields.source_id])
def test_build_tf_record_input_reader_and_load_instance_masks(self):
    """Reading with load_instance_masks yields a (1, 4, 5) instance-mask array."""
    tf_record_path = self.create_tf_record()
    proto_text = """ shuffle: false num_readers: 1 load_instance_masks: true tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(proto_text, reader_proto)

    dataset = dataset_builder.build(reader_proto)
    next_element = dataset_util.make_initializable_iterator(dataset).get_next()

    supervisor = tf.train.Supervisor(logdir=self.get_temp_dir())
    with supervisor.prepare_or_wait_for_session() as sess:
        supervisor.start_queue_runners(sess)
        output_dict = sess.run(next_element)

    masks = output_dict[fields.InputDataFields.groundtruth_instance_masks]
    self.assertAllEqual((1, 4, 5), masks.shape)