def test_dynamic_get_attribute_value(int_dtype, fp_dtype):
    """Build a detection_output node and read every attribute back via its getters.

    The attribute dictionary uses ``attrs.``-prefixed keys, and each getter on
    the resulting node is compared against the value that was passed in.
    """
    # NOTE(review): a second function with this exact name is defined later in
    # this file; the later definition rebinds the module-level name, so this
    # one is presumably never collected by the test runner -- confirm.
    attributes = {
        "attrs.num_classes": int_dtype(85),
        "attrs.background_label_id": int_dtype(13),
        "attrs.top_k": int_dtype(16),
        "attrs.variance_encoded_in_target": True,
        "attrs.keep_top_k": np.array([64, 32, 16, 8], dtype=int_dtype),
        "attrs.code_type": "pytorch.some_parameter_name",
        "attrs.share_location": False,
        "attrs.nms_threshold": fp_dtype(0.645),
        "attrs.confidence_threshold": fp_dtype(0.111),
        "attrs.clip_after_nms": True,
        "attrs.clip_before_nms": False,
        "attrs.decrease_label_id": True,
        "attrs.normalized": True,
        "attrs.input_height": int_dtype(86),
        "attrs.input_width": int_dtype(79),
        "attrs.objectness_score": fp_dtype(0.77),
    }

    # Graph inputs, keyed by parameter name; insertion order is creation order.
    input_shapes = {
        "box_logits": [4, 1, 5, 5],
        "class_preds": [2, 1, 4, 5],
        "proposals": [2, 1, 4, 5],
        "aux_class_preds": [2, 1, 4, 5],
        "aux_box_preds": [2, 1, 4, 5],
    }
    inputs = {
        name: ng.parameter(shape, fp_dtype, name)
        for name, shape in input_shapes.items()
    }

    node = ng.detection_output(
        inputs["box_logits"],
        inputs["class_preds"],
        inputs["proposals"],
        attributes,
        inputs["aux_class_preds"],
        inputs["aux_box_preds"],
    )

    # Integer-valued attributes.
    assert node.get_num_classes() == int_dtype(85)
    assert node.get_background_label_id() == int_dtype(13)
    assert node.get_top_k() == int_dtype(16)
    assert node.get_variance_encoded_in_target()

    # Array-valued attribute: compared element-wise.
    keep_top_k_matches = np.equal(
        node.get_keep_top_k(), np.array([64, 32, 16, 8], dtype=int_dtype)
    )
    assert np.all(keep_top_k_matches)

    assert node.get_code_type() == "pytorch.some_parameter_name"
    assert not node.get_share_location()

    # Floating-point attributes are compared with a tolerance.
    assert np.isclose(node.get_nms_threshold(), fp_dtype(0.645))
    assert np.isclose(node.get_confidence_threshold(), fp_dtype(0.111))

    assert node.get_clip_after_nms()
    assert not node.get_clip_before_nms()
    assert node.get_decrease_label_id()
    assert node.get_normalized()
    assert node.get_input_height() == int_dtype(86)
    assert node.get_input_width() == int_dtype(79)
    assert np.isclose(node.get_objectness_score(), fp_dtype(0.77))

    # num_classes is read a second time after all other getters have run.
    assert node.get_num_classes() == int_dtype(85)
def test_dynamic_get_attribute_value(int_dtype, fp_dtype):
    """Build a detection_output node and read every attribute back via its getters.

    Each getter on the resulting node is compared against the value that was
    supplied in the attribute dictionary at construction time.
    """
    # NOTE(review): an earlier function in this file carries the same name;
    # this later definition is the one left bound at module level.
    attributes = {
        "num_classes": int_dtype(85),
        "background_label_id": int_dtype(13),
        "top_k": int_dtype(16),
        "variance_encoded_in_target": True,
        "keep_top_k": np.array([64, 32, 16, 8], dtype=int_dtype),
        "code_type": "caffe.PriorBoxParameter.CENTER_SIZE",
        "share_location": False,
        "nms_threshold": fp_dtype(0.645),
        "confidence_threshold": fp_dtype(0.111),
        "clip_after_nms": True,
        "clip_before_nms": False,
        "decrease_label_id": True,
        "normalized": True,
        "input_height": int_dtype(86),
        "input_width": int_dtype(79),
        "objectness_score": fp_dtype(0.77),
    }

    # Graph inputs, keyed by parameter name; insertion order is creation order.
    input_shapes = {
        "box_logits": [4, 680],
        "class_preds": [4, 170],
        "proposals": [4, 1, 8],
        "aux_class_preds": [4, 4],
        "aux_box_preds": [4, 680],
    }
    inputs = {
        name: ng.parameter(shape, fp_dtype, name)
        for name, shape in input_shapes.items()
    }

    node = ng.detection_output(
        inputs["box_logits"],
        inputs["class_preds"],
        inputs["proposals"],
        attributes,
        inputs["aux_class_preds"],
        inputs["aux_box_preds"],
    )

    # Integer-valued attributes.
    assert node.get_num_classes() == int_dtype(85)
    assert node.get_background_label_id() == int_dtype(13)
    assert node.get_top_k() == int_dtype(16)
    assert node.get_variance_encoded_in_target()

    # Array-valued attribute: compared element-wise.
    keep_top_k_matches = np.equal(
        node.get_keep_top_k(), np.array([64, 32, 16, 8], dtype=int_dtype)
    )
    assert np.all(keep_top_k_matches)

    assert node.get_code_type() == "caffe.PriorBoxParameter.CENTER_SIZE"
    assert not node.get_share_location()

    # Floating-point attributes are compared with a tolerance.
    assert np.isclose(node.get_nms_threshold(), fp_dtype(0.645))
    assert np.isclose(node.get_confidence_threshold(), fp_dtype(0.111))

    assert node.get_clip_after_nms()
    assert not node.get_clip_before_nms()
    assert node.get_decrease_label_id()
    assert node.get_normalized()
    assert node.get_input_height() == int_dtype(86)
    assert node.get_input_width() == int_dtype(79)
    assert np.isclose(node.get_objectness_score(), fp_dtype(0.77))

    # num_classes is read a second time after all other getters have run.
    assert node.get_num_classes() == int_dtype(85)
def test_detection_output(int_dtype, fp_dtype):
    """Create a detection_output node and check its type name and single output.

    Only ``keep_top_k`` and ``nms_threshold`` are supplied as attributes; the
    test then asserts the node's type name, output count and output shape.
    """
    attributes = {
        "keep_top_k": np.array([64], dtype=int_dtype),
        "nms_threshold": fp_dtype(0.645),
    }

    # Graph inputs, keyed by parameter name; insertion order is creation order.
    input_shapes = {
        "box_logits": [4, 8],
        "class_preds": [4, 170],
        "proposals": [4, 2, 10],
        "aux_class_preds": [4, 4],
        "aux_box_preds": [4, 8],
    }
    inputs = {
        name: ng.parameter(shape, fp_dtype, name)
        for name, shape in input_shapes.items()
    }

    node = ng.detection_output(
        inputs["box_logits"],
        inputs["class_preds"],
        inputs["proposals"],
        attributes,
        inputs["aux_class_preds"],
        inputs["aux_box_preds"],
    )

    assert node.get_type_name() == "DetectionOutput"
    assert node.get_output_size() == 1

    expected_shape = [1, 1, 256, 7]
    assert list(node.get_output_shape(0)) == expected_shape