def __init__(self):
    """Build the wrapped keypoint head and store it on ``self.model``.

    The head is configured for 4-channel, 14x14 inputs with
    ``num_keypoints=17`` and ``conv_dims=(4,)``.
    """
    super().__init__()
    head_input = ShapeSpec(channels=4, height=14, width=14)
    self.model = KRCNNConvDeconvUpsampleHead(
        head_input,
        num_keypoints=17,
        conv_dims=(4,),
    )
Ejemplo n.º 2
0
    def test_keypoint_head_scriptability(self):
        """Check that a TorchScript-compiled keypoint head matches eager mode.

        Builds a ``KRCNNConvDeconvUpsampleHead`` in eval mode and runs it
        eagerly on predicted boxes for two images. It then scripts the head
        (with training mode frozen and ``Instances`` patched to a class that
        declares the fields listed below) and runs it on the same features and
        boxes. Every per-image output must match exactly (``rtol=0``).
        """
        input_shape = ShapeSpec(channels=1024, height=14, width=14)
        # Batch of 4 = 3 boxes in image 0 + 1 box in image 1 (presumably one
        # pooled feature per box -- confirm against the head's contract).
        keypoint_features = torch.randn(4, 1024, 14, 14)

        image_shapes = [(10, 10), (15, 15)]
        pred_boxes0 = torch.tensor(
            [[1, 1, 3, 3], [2, 2, 6, 6], [1, 5, 2, 8]], dtype=torch.float32
        )
        pred_instance0 = Instances(image_shapes[0])
        pred_instance0.pred_boxes = Boxes(pred_boxes0)
        pred_boxes1 = torch.tensor([[7, 3, 10, 5]], dtype=torch.float32)
        pred_instance1 = Instances(image_shapes[1])
        pred_instance1.pred_boxes = Boxes(pred_boxes1)

        keypoint_head = KRCNNConvDeconvUpsampleHead(
            input_shape, num_keypoints=17, conv_dims=[512, 512]
        ).eval()
        # The eager run gets deep copies so the original Instances stay
        # untouched for the scripted run (presumably the head writes its
        # predictions onto the Instances it is given).
        origin_outputs = keypoint_head(
            keypoint_features, deepcopy([pred_instance0, pred_instance1])
        )

        # Fields (and their types) that the scriptable Instances class must declare.
        fields = {
            "pred_boxes": Boxes,
            "pred_keypoints": torch.Tensor,
            "pred_keypoint_heatmaps": torch.Tensor,
        }
        with freeze_training_mode(keypoint_head), patch_instances(fields) as NewInstances:
            script_keypoint_head = torch.jit.script(keypoint_head)
            script_inputs = [
                NewInstances.from_instances(pred_instance0),
                NewInstances.from_instances(pred_instance1),
            ]
            script_outputs = script_keypoint_head(keypoint_features, script_inputs)

        # zip() stops at the shorter sequence, so a missing per-image output
        # would otherwise go unnoticed. Check the counts first.
        assert len(script_outputs) == len(origin_outputs), (
            f"scripted head returned {len(script_outputs)} outputs, "
            f"eager head returned {len(origin_outputs)}"
        )
        for origin_ins, script_ins in zip(origin_outputs, script_outputs):
            assert_instances_allclose(origin_ins, script_ins.to_instances(), rtol=0)