def _assert_model_fn_for_predict(self, configs):
    """Checks the estimator spec produced by the model_fn in PREDICT mode.

    Features are drawn from the eval input function, the model_fn is called
    with `None` labels, and the resulting spec is expected to carry
    predictions and export outputs but no loss and no train op.

    Args:
      configs: Mapping read here for the keys 'model', 'eval_config' and
        'eval_input_config'; it is also handed whole to
        `model_lib.create_model_fn`.
    """
    model_config = configs['model']
    with tf.Graph().as_default():
        # Only the features are needed for prediction; labels are discarded.
        eval_input_fn = inputs.create_eval_input_fn(
            configs['eval_config'],
            configs['eval_input_config'],
            configs['model'])
        iterator = _make_initializable_iterator(eval_input_fn())
        features, _ = iterator.get_next()

        build_detection_model = functools.partial(
            model_builder.build,
            model_config=model_config,
            is_training=False)
        hparams = model_hparams.create_hparams(
            hparams_overrides='load_pretrained=false')
        model_fn = model_lib.create_model_fn(
            build_detection_model, configs, hparams)

        spec = model_fn(features, None, tf.estimator.ModeKeys.PREDICT)

        # No loss / train op in PREDICT mode.
        self.assertIsNone(spec.loss)
        self.assertIsNone(spec.train_op)
        # Predictions and export signatures must be populated.
        self.assertIsNotNone(spec.predictions)
        self.assertIsNotNone(spec.export_outputs)
        predict_key = tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        self.assertIn(predict_key, spec.export_outputs)
def test_error_with_bad_eval_model_config(self):
    """Tests that a TypeError is raised with improper eval model config.

    The eval config is deliberately supplied in the `model_config` slot; the
    error is expected when the returned input function is invoked, not when
    it is created.
    """
    configs = _get_configs_for_model('ssd_inception_v2_pets')
    configs['model'].ssd.num_classes = 37

    # Wrong on purpose: `model_config` expects a `DetectionModel`, but is
    # given the eval config instead.
    wrong_model_config = configs['eval_config']
    bad_input_fn = inputs.create_eval_input_fn(
        eval_config=configs['eval_config'],
        eval_input_config=configs['eval_input_config'],
        model_config=wrong_model_config)

    with self.assertRaises(TypeError):
        bad_input_fn()
def test_faster_rcnn_resnet50_eval_input(self):
    """Tests the eval input function for FasterRcnnResnet50.

    Builds the eval input function from the 'faster_rcnn_resnet50_pets'
    configs (with num_classes set to 37) and verifies the static shape and
    dtype of every feature and label tensor it yields.
    """
    configs = _get_configs_for_model('faster_rcnn_resnet50_pets')
    model_config = configs['model']
    model_config.faster_rcnn.num_classes = 37

    eval_input_fn = inputs.create_eval_input_fn(
        configs['eval_config'], configs['eval_input_config'], model_config)
    iterator = _make_initializable_iterator(eval_input_fn())
    features, labels = iterator.get_next()

    def expect(tensors, key, shape, dtype):
        # Shape is checked before dtype, mirroring one assertion pair per
        # tensor.
        self.assertAllEqual(shape, tensors[key].shape.as_list())
        self.assertEqual(dtype, tensors[key].dtype)

    data_fields = fields.InputDataFields

    # Features.
    expect(features, data_fields.image, [1, None, None, 3], tf.float32)
    expect(features, data_fields.original_image, [1, None, None, 3],
           tf.uint8)
    expect(features, inputs.HASH_KEY, [1], tf.int32)

    # Labels.
    expect(labels, data_fields.groundtruth_boxes, [1, 100, 4], tf.float32)
    expect(labels, data_fields.groundtruth_classes,
           [1, 100, model_config.faster_rcnn.num_classes], tf.float32)
    expect(labels, data_fields.groundtruth_area, [1, 100], tf.float32)
    expect(labels, data_fields.groundtruth_is_crowd, [1, 100], tf.bool)
    expect(labels, data_fields.groundtruth_difficult, [1, 100], tf.int32)
def _assert_model_fn_for_train_eval(self, configs, mode, class_agnostic=False): model_config = configs['model'] train_config = configs['train_config'] with tf.Graph().as_default(): if mode == 'train': features, labels = _make_initializable_iterator( inputs.create_train_input_fn( configs['train_config'], configs['train_input_config'], configs['model'])()).get_next() model_mode = tf.estimator.ModeKeys.TRAIN batch_size = train_config.batch_size elif mode == 'eval': features, labels = _make_initializable_iterator( inputs.create_eval_input_fn( configs['eval_config'], configs['eval_input_config'], configs['model'])()).get_next() model_mode = tf.estimator.ModeKeys.EVAL batch_size = 1 elif mode == 'eval_on_train': features, labels = _make_initializable_iterator( inputs.create_eval_input_fn( configs['eval_config'], configs['train_input_config'], configs['model'])()).get_next() model_mode = tf.estimator.ModeKeys.EVAL batch_size = 1 detection_model_fn = functools.partial(model_builder.build, model_config=model_config, is_training=True) hparams = model_hparams.create_hparams( hparams_overrides='load_pretrained=false') model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams) estimator_spec = model_fn(features, labels, model_mode) self.assertIsNotNone(estimator_spec.loss) self.assertIsNotNone(estimator_spec.predictions) if mode == 'eval' or mode == 'eval_on_train': if class_agnostic: self.assertNotIn('detection_classes', estimator_spec.predictions) else: detection_classes = estimator_spec.predictions[ 'detection_classes'] self.assertEqual(batch_size, detection_classes.shape.as_list()[0]) self.assertEqual(tf.float32, detection_classes.dtype) detection_boxes = estimator_spec.predictions['detection_boxes'] detection_scores = estimator_spec.predictions[ 'detection_scores'] num_detections = estimator_spec.predictions['num_detections'] self.assertEqual(batch_size, detection_boxes.shape.as_list()[0]) self.assertEqual(tf.float32, detection_boxes.dtype) 
self.assertEqual(batch_size, detection_scores.shape.as_list()[0]) self.assertEqual(tf.float32, detection_scores.dtype) self.assertEqual(tf.float32, num_detections.dtype) if model_mode == tf.estimator.ModeKeys.TRAIN: self.assertIsNotNone(estimator_spec.train_op) return estimator_spec