def test_heatmaps_to_keypoints(self):
    """Compare eager and traced ``heatmaps_to_keypoints`` outputs.

    The function is traced once on a first pair of random inputs. The same
    trace is then run on a second, differently shaped pair, and both
    returned elements are compared against the eager results each time.
    """
    from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

    # First pair of inputs: also used as the example inputs for tracing.
    heatmaps = torch.rand(10, 1, 26, 26)
    boxes = torch.rand(10, 4)
    eager = heatmaps_to_keypoints(heatmaps, boxes)
    traced_fn = torch.jit.trace(heatmaps_to_keypoints, (heatmaps, boxes))
    traced = traced_fn(heatmaps, boxes)
    for idx in (0, 1):
        assert_equal(eager[idx], traced[idx])

    # Second pair with different shapes, fed through the already-built trace.
    heatmaps2 = torch.rand(20, 2, 21, 21)
    boxes2 = torch.rand(20, 4)
    eager2 = heatmaps_to_keypoints(heatmaps2, boxes2)
    traced2 = traced_fn(heatmaps2, boxes2)
    for idx in (0, 1):
        assert_equal(eager2[idx], traced2[idx])
def test_heatmaps_to_keypoints(self):
    """Check that a traced ``heatmaps_to_keypoints`` matches eager execution.

    The function is traced once on a first pair of random inputs. The same
    trace is then run on a second, differently shaped pair, and both
    returned elements must be element-wise equal to the eager results.

    The JIT profiling executor and profiling mode are switched off for the
    duration of the test and restored afterwards, so the process-wide
    setting does not leak into tests that run later.
    """
    from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

    # disable profiling
    # NOTE(review): relies on these private setters returning the previous
    # flag value, as PyTorch's own JIT test utilities do -- confirm against
    # the minimum supported torch version.
    prev_executor = torch._C._jit_set_profiling_executor(False)
    prev_mode = torch._C._jit_set_profiling_mode(False)
    try:
        # First pair of inputs: also used as the example inputs for tracing.
        maps = torch.rand(10, 1, 26, 26)
        rois = torch.rand(10, 4)
        out = heatmaps_to_keypoints(maps, rois)
        jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
        out_trace = jit_trace(maps, rois)

        assert torch.all(out[0].eq(out_trace[0]))
        assert torch.all(out[1].eq(out_trace[1]))

        # Second pair with different shapes, fed through the existing trace.
        maps2 = torch.rand(20, 2, 21, 21)
        rois2 = torch.rand(20, 4)
        out2 = heatmaps_to_keypoints(maps2, rois2)
        out_trace2 = jit_trace(maps2, rois2)

        assert torch.all(out2[0].eq(out_trace2[0]))
        assert torch.all(out2[1].eq(out_trace2[1]))
    finally:
        # Put the global JIT flags back even if an assertion above failed.
        torch._C._jit_set_profiling_executor(prev_executor)
        torch._C._jit_set_profiling_mode(prev_mode)