def test_raises_error_with_no_input_paths(self):
    """Building from a config with no tf_record_input_reader raises ValueError."""
    # The config deliberately omits the `tf_record_input_reader` section.
    config_text = """ shuffle: false num_readers: 1 load_instance_masks: true """
    reader_config = input_reader_pb2.InputReader()
    text_format.Merge(config_text, reader_config)
    self.assertRaises(
        ValueError, dataset_builder.build, reader_config, batch_size=1)
def test_build_tf_record_input_reader(self):
    """Reads one example from a TFRecord through dataset_builder and checks it.

    The config does not set `load_instance_masks`, and the test asserts that
    no instance-mask field appears in the output. It also checks the image
    shape, the class label and the box tensor for a batch size of 1.
    """
    tf_record_path = self.create_tf_record()
    # Braces are doubled because the text proto is passed through str.format.
    input_reader_text_proto = """ shuffle: false num_readers: 1 tf_record_input_reader {{ input_path: '{0}' }} """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
        sv.start_queue_runners(sess)
        output_dict = sess.run(tensor_dict)

    # assertNotIn reports the offending container on failure, unlike
    # assertTrue(x not in y).
    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                     output_dict)
    # assertEqual, not the deprecated assertEquals alias (removed in
    # Python 3.12).
    self.assertEqual((1, 4, 5, 3),
                     output_dict[fields.InputDataFields.image].shape)
    self.assertAllEqual(
        [[2]], output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0][0])
def test_build_tf_record_input_reader_with_batch_size_two_and_masks(self):
    """Batches two examples with masks enabled and a class transform applied.

    A transform one-hot encodes `groundtruth_classes` (labels shifted by -1,
    depth 3) before batching. The test then checks the shape of the batched
    instance-mask tensor.
    """
    record_file = self.create_tf_record()
    config_text = """ shuffle: false num_readers: 1 load_instance_masks: true tf_record_input_reader {{ input_path: '{0}' }} """.format(record_file)
    reader_config = input_reader_pb2.InputReader()
    text_format.Merge(config_text, reader_config)

    def _encode_classes_one_hot(example):
        # Mutates the incoming dict in place and hands the same object back.
        class_key = fields.InputDataFields.groundtruth_classes
        example[class_key] = tf.one_hot(example[class_key] - 1, depth=3)
        return example

    batched_dataset = dataset_builder.build(
        reader_config,
        transform_input_data_fn=_encode_classes_one_hot,
        batch_size=2)
    next_batch = dataset_builder.make_initializable_iterator(
        batched_dataset).get_next()

    supervisor = tf.train.Supervisor(logdir=self.get_temp_dir())
    with supervisor.prepare_or_wait_for_session() as session:
        supervisor.start_queue_runners(session)
        batch_values = session.run(next_batch)

    mask_shape = batch_values[
        fields.InputDataFields.groundtruth_instance_masks].shape
    self.assertAllEqual([2, 1, 4, 5], mask_shape)
def get_next(config):
    """Build a dataset from `config` and return its next-element tensors.

    The dataset is wrapped in an initializable iterator obtained through
    `dataset_builder.make_initializable_iterator`.
    """
    dataset = dataset_builder.build(config)
    iterator = dataset_builder.make_initializable_iterator(dataset)
    return iterator.get_next()