def __init__(self):
    """Build the wrapper around a deliberately small keypoint head.

    The wrapped head is stored as ``self.model``.
    """
    # Base-class initialisation has to happen before any submodule is
    # attached (presumably the base is torch.nn.Module -- confirm).
    super().__init__()
    # Input features: 4 channels on a 14x14 grid.
    feature_shape = ShapeSpec(channels=4, height=14, width=14)
    # 17 keypoints, a single conv layer with 4 output channels.
    self.model = KRCNNConvDeconvUpsampleHead(
        feature_shape,
        num_keypoints=17,
        conv_dims=(4,),
    )
def test_keypoint_head_scriptability(self):
    """Check that a TorchScript-compiled keypoint head matches eager mode.

    The eager ``KRCNNConvDeconvUpsampleHead`` runs on predicted boxes for two
    images. The same head is then scripted under ``patch_instances``. Every
    resulting ``Instances`` must equal its eager counterpart with
    ``rtol=0``, and both runs must return the same number of ``Instances``.
    """
    input_shape = ShapeSpec(channels=1024, height=14, width=14)
    # One 1024x14x14 feature map per predicted box:
    # 3 boxes in image 0 plus 1 box in image 1 = 4 boxes.
    keypoint_features = torch.randn(4, 1024, 14, 14)

    image_shapes = [(10, 10), (15, 15)]
    pred_boxes0 = torch.tensor(
        [[1, 1, 3, 3], [2, 2, 6, 6], [1, 5, 2, 8]], dtype=torch.float32
    )
    pred_instance0 = Instances(image_shapes[0])
    pred_instance0.pred_boxes = Boxes(pred_boxes0)
    pred_boxes1 = torch.tensor([[7, 3, 10, 5]], dtype=torch.float32)
    pred_instance1 = Instances(image_shapes[1])
    pred_instance1.pred_boxes = Boxes(pred_boxes1)

    keypoint_head = KRCNNConvDeconvUpsampleHead(
        input_shape, num_keypoints=17, conv_dims=[512, 512]
    ).eval()
    # The eager run gets deep copies so it cannot alter the Instances that
    # are reused as inputs for the scripted run below (the head presumably
    # attaches its predictions to its inputs -- confirm).
    origin_outputs = keypoint_head(
        keypoint_features, deepcopy([pred_instance0, pred_instance1])
    )

    # Field name -> type that the patched, scriptable Instances class must
    # declare. This covers the input boxes plus the fields the head fills in.
    fields = {
        "pred_boxes": Boxes,
        "pred_keypoints": torch.Tensor,
        "pred_keypoint_heatmaps": torch.Tensor,
    }
    with freeze_training_mode(keypoint_head), patch_instances(fields) as NewInstances:
        script_keypoint_head = torch.jit.script(keypoint_head)
        pred_instance0 = NewInstances.from_instances(pred_instance0)
        pred_instance1 = NewInstances.from_instances(pred_instance1)
        script_outputs = script_keypoint_head(
            keypoint_features, [pred_instance0, pred_instance1]
        )

    # zip() would silently drop unmatched outputs, so compare counts first.
    assert len(script_outputs) == len(origin_outputs)
    for origin_ins, script_ins in zip(origin_outputs, script_outputs):
        assert_instances_allclose(origin_ins, script_ins.to_instances(), rtol=0)