Code example #1
    def test_splatting_type_values(self):
        # A 1x2 frame with values [1, 2]; the flow shifts pixel 0 one step to
        # the right, so both source pixels are splatted onto output pixel 1.
        frame = torch.tensor([1, 2], dtype=torch.float32).reshape([1, 1, 1, 2])
        flow = torch.zeros([1, 2, 1, 2], dtype=torch.float32)
        flow[0, 0, 0, 0] = 1
        importance_metric = torch.tensor([1, 2], dtype=torch.float32).reshape(
            [1, 1, 1, 2])

        # summation splatting: 1 + 2
        output = splatting.splatting_function("summation", frame, flow)
        assert output[0, 0, 0, 1] == pytest.approx(3)

        # average splatting: (1 + 2) / 2
        output = splatting.splatting_function("average", frame, flow)
        assert output[0, 0, 0, 1] == pytest.approx(1.5)

        # linear splatting: (1 * 1 + 2 * 2) / (1 + 2)
        output = splatting.splatting_function("linear", frame, flow,
                                              importance_metric)
        assert output[0, 0, 0, 1] == pytest.approx(5.0 / 3.0)

        # softmax splatting: (1 * exp(1) + 2 * exp(2)) / (exp(1) + exp(2))
        output = splatting.splatting_function("softmax", frame, flow,
                                              importance_metric)
        assert output[0, 0, 0, 1] == pytest.approx(
            (exp(1) + 2 * exp(2)) / (exp(1) + exp(2)))
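As a quick sanity check on the softmax case above, the expected value can be reproduced by hand; this is a minimal sketch using only the standard library, not part of the test suite:

from math import exp

# Both source pixels (values 1 and 2) land on output pixel 1; softmax
# splatting weights them by exp(importance) with importance = [1, 2].
weights = [exp(1), exp(2)]
values = [1.0, 2.0]
expected = sum(w * v for w, v in zip(weights, values)) / sum(weights)
print(expected)  # ~1.7311, matching the assertion above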
Code example #2
    def test_splatting_type_names(self):
        # All four supported splatting types are accepted; an unknown name
        # raises NotImplementedError.
        frame = torch.zeros(1, 1, 3, 3)
        flow = torch.zeros(1, 2, 3, 3)
        importance_metric = torch.ones_like(frame)
        splatting.splatting_function("summation", frame, flow)
        splatting.splatting_function("average", frame, flow)
        splatting.splatting_function("linear", frame, flow, importance_metric)
        splatting.splatting_function("softmax", frame, flow, importance_metric)
        with pytest.raises(NotImplementedError):
            splatting.splatting_function("something_else", frame, flow,
                                         importance_metric)
Code example #3
    def test_flow_one(self):
        # A single non-zero pixel at (0, 0) with a flow of (+1, +1) should be
        # splatted onto (1, 1).
        frame = torch.zeros(1, 1, 3, 3)
        frame[0, :, 0, 0] = 1
        flow = torch.zeros(1, 2, 3, 3)
        flow[0, :, 0, 0] = 1
        target = torch.zeros(1, 1, 3, 3)
        target[0, :, 1, 1] = 1
        output = splatting.splatting_function("summation", frame, flow)
        assert torch.equal(output, target)
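For contrast, here is a minimal companion sketch (not part of the original suite) that shifts the pixel only along flow channel 0, which the value test in Code example #1 identifies as the horizontal displacement:

    def test_flow_one_x_only(self):
        # A single non-zero pixel at (0, 0) with a flow of (+1, 0) should be
        # splatted one column to the right, onto (0, 1).
        frame = torch.zeros(1, 1, 3, 3)
        frame[0, :, 0, 0] = 1
        flow = torch.zeros(1, 2, 3, 3)
        flow[0, 0, 0, 0] = 1  # x-displacement only
        target = torch.zeros(1, 1, 3, 3)
        target[0, :, 0, 1] = 1
        output = splatting.splatting_function("summation", frame, flow)
        assert torch.equal(output, target)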
Code example #4
# Imports this helper relies on (numpy, torch, and the splatting package).
import numpy as np
import torch

from splatting import splatting_function


def render_forward(src_ims, src_dms, Rcam, tcam, K_src, K_dst):
    """Forward-warp source views into the target camera via summation splatting.

    src_ims: (B, H, W, C) source images, channels last
    src_dms: (B, H, W) source depth maps
    Rcam, tcam: rotation (3, 3) and translation (3,) of the target camera
    K_src, K_dst: (3, 3) source and target intrinsics
    """
    Rcam = Rcam.to(device=src_ims.device)[None]
    tcam = tcam.to(device=src_ims.device)[None]

    R = Rcam
    t = tcam[..., None]
    K_src_inv = K_src.inverse()

    assert len(src_ims.shape) == 4
    assert len(src_dms.shape) == 3
    assert src_ims.shape[1:3] == src_dms.shape[1:3], (src_ims.shape,
                                                      src_dms.shape)

    # Homogeneous pixel grid of the source view: (1, h, w, 3) with z = 1.
    x = np.arange(src_ims[0].shape[1])
    y = np.arange(src_ims[0].shape[0])
    coord = np.stack(np.meshgrid(x, y), -1)
    coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1)  # z=1
    coord = coord.astype(np.float32)
    coord = torch.as_tensor(coord, dtype=K_src.dtype, device=K_src.device)
    coord = coord[None]  # bs, h, w, 3

    D = src_dms[:, :, :, None, None]

    # Unproject with K_src, transform into the target camera, reproject with
    # K_dst: points = K_dst (R (D K_src^-1 coord) + t).
    points = K_dst[None, None, None, ...] @ (R[:, None, None, ...] @ (
        D * K_src_inv[None, None, None, ...] @ coord[:, :, :, :, None]) +
                                             t[:, None, None, :, :])
    points = points.squeeze(-1)

    # Depth in the target camera and perspective division to pixel coordinates.
    new_z = points[:, :, :, [2]].clone().permute(0, 3, 1, 2)  # b,1,h,w
    points = points / torch.clamp(points[:, :, :, [2]], 1e-8, None)

    # Flow from source pixel positions to target pixel positions, (b, 2, h, w).
    src_ims = src_ims.permute(0, 3, 1, 2)
    flow = points - coord
    flow = flow.permute(0, 3, 1, 2)[:, :2, ...]

    # Importance weights from inverse depth, rescaled to [-10, 0] and
    # exponentiated so nearer points dominate where splats overlap.
    alpha = 0.5
    importance = alpha / new_z
    importance_min = importance.amin((1, 2, 3), keepdim=True)
    importance_max = importance.amax((1, 2, 3), keepdim=True)
    importance = (importance - importance_min) / (
        importance_max - importance_min + 1e-6) * 10 - 10
    importance = importance.exp()

    # Splat importance-weighted colors together with the weights themselves,
    # then normalize; summing over dim 0 merges all source views into one image.
    input_data = torch.cat([importance * src_ims, importance], 1)
    output_data = splatting_function("summation", input_data, flow)

    num = torch.sum(output_data[:, :-1, :, :], dim=0, keepdim=True)
    denom = torch.sum(output_data[:, -1:, :, :], dim=0, keepdim=True)

    rendered = num / (denom + 1e-7)
    rendered = rendered.permute(0, 2, 3, 1)[0, ...]
    return rendered
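A minimal usage sketch for render_forward, assuming channels-last RGB images, an identity camera motion, unit depth, and an illustrative pinhole intrinsic matrix (the names and values below are made up for the example, not taken from the project):

import torch

H, W = 64, 64
src_ims = torch.rand(1, H, W, 3)   # (B, H, W, C) source images
src_dms = torch.ones(1, H, W)      # (B, H, W) depth maps
Rcam = torch.eye(3)                # target camera rotation
tcam = torch.zeros(3)              # target camera translation
K = torch.tensor([[50.0, 0.0, W / 2],
                  [0.0, 50.0, H / 2],
                  [0.0, 0.0, 1.0]])  # shared source/target intrinsics

# With no camera motion and matching intrinsics the flow is (nearly) zero,
# so the rendered view should closely reproduce the source image.
rendered = render_forward(src_ims, src_dms, Rcam, tcam, K_src=K, K_dst=K)
print(rendered.shape)  # torch.Size([64, 64, 3])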
Code example #5
    def test_importance_metric_type_and_shape(self):
        # summation/average must be called without an importance metric;
        # linear/softmax require one shaped (B, 1, H, W) to match the frame.
        frame = torch.ones([1, 1, 3, 3])
        flow = torch.zeros([1, 2, 3, 3])
        importance_metric = frame.new_ones([1, 1, 3, 3])
        wrong_metric_0 = frame.new_ones([2, 1, 3, 3])
        wrong_metric_1 = frame.new_ones([1, 2, 3, 3])
        wrong_metric_2 = frame.new_ones([1, 1, 2, 3])
        wrong_metric_3 = frame.new_ones([1, 1, 3, 2])

        # summation splatting
        splatting.splatting_function("summation", frame, flow)
        with pytest.raises(AssertionError):
            splatting.splatting_function("summation", frame, flow,
                                         importance_metric)

        # average splatting
        splatting.splatting_function("average", frame, flow)
        with pytest.raises(AssertionError):
            splatting.splatting_function("average", frame, flow,
                                         importance_metric)

        # linear splatting
        splatting.splatting_function("linear", frame, flow, importance_metric)
        with pytest.raises(AssertionError):
            splatting.splatting_function("linear", frame, flow)
        with pytest.raises(AssertionError):
            splatting.splatting_function("linear", frame, flow, wrong_metric_0)
        with pytest.raises(AssertionError):
            splatting.splatting_function("linear", frame, flow, wrong_metric_1)
        with pytest.raises(AssertionError):
            splatting.splatting_function("linear", frame, flow, wrong_metric_2)
        with pytest.raises(AssertionError):
            splatting.splatting_function("linear", frame, flow, wrong_metric_3)

        # softmax splatting
        splatting.splatting_function("softmax", frame, flow, importance_metric)
        with pytest.raises(AssertionError):
            splatting.splatting_function("softmax", frame, flow)
        with pytest.raises(AssertionError):
            splatting.splatting_function("softmax", frame, flow,
                                         wrong_metric_0)
        with pytest.raises(AssertionError):
            splatting.splatting_function("softmax", frame, flow,
                                         wrong_metric_1)
        with pytest.raises(AssertionError):
            splatting.splatting_function("softmax", frame, flow,
                                         wrong_metric_2)
        with pytest.raises(AssertionError):
            splatting.splatting_function("softmax", frame, flow,
                                         wrong_metric_3)
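As the assertions above suggest, a valid importance metric for linear or softmax splatting has a single channel and shares the frame's batch and spatial dimensions. A small illustrative sketch of constructing one (the concrete sizes are arbitrary):

# frame has shape (B, C, H, W); the metric must have shape (B, 1, H, W).
frame = torch.ones(2, 3, 16, 16)
importance_metric = frame.new_ones(
    [frame.shape[0], 1, frame.shape[2], frame.shape[3]])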
Code example #6
File: benchmark.py  Project: ywu40/splatting
import timeit

import torch


def run_test_backward(method, batch_size, spatial_size, flow_init, repetitions=10):
    # Time the backward pass of the selected splatting entry point, averaged
    # over `repetitions` runs.
    frame = torch.ones(batch_size, 3, spatial_size, spatial_size)
    if flow_init == "zeros":
        flow = torch.zeros(batch_size, 2, spatial_size, spatial_size)
    elif flow_init == "ones":
        flow = torch.ones(batch_size, 2, spatial_size, spatial_size)
    else:
        raise NotImplementedError
    if method == "splatting_cpu":
        import splatting.cpu

        grad_output = torch.zeros_like(frame)
        grad_frame = torch.zeros_like(frame)
        grad_flow = torch.zeros_like(flow)

        def test_fn():
            splatting.cpu.splatting_backward_cpu(
                frame, flow, grad_output, grad_frame, grad_flow
            )

    elif method == "splatting_cuda":
        import splatting.cuda

        frame = frame.cuda()
        flow = flow.cuda()
        grad_output = torch.zeros_like(frame)
        grad_frame = torch.zeros_like(frame)
        grad_flow = torch.zeros_like(flow)

        def test_fn():
            splatting.cuda.splatting_backward_cuda(
                frame, flow, grad_output, grad_frame, grad_flow
            )
            torch.cuda.synchronize()

    elif method == "splatting_function":
        import splatting

        frame.requires_grad_(True)
        flow.requires_grad_(True)
        output = splatting.SummationSplattingFunction.apply(frame, flow).sum()

        def test_fn():
            output.backward(retain_graph=True)

    elif method == "splatting_function_summation":
        import splatting

        frame.requires_grad_(True)
        flow.requires_grad_(True)
        output = splatting.splatting_function("summation", frame, flow).sum()

        def test_fn():
            output.backward(retain_graph=True)

    elif method == "splatting_module_summation":
        import splatting

        frame.requires_grad_(True)
        flow.requires_grad_(True)
        splatting_module = splatting.Splatting("summation")
        output = splatting_module(frame, flow).sum()

        def test_fn():
            output.backward(retain_graph=True)

    elif method == "splatting_module_softmax":
        import splatting

        frame.requires_grad_(True)
        flow.requires_grad_(True)
        importance_metric = frame.new_ones(
            [frame.shape[0], 1, frame.shape[2], frame.shape[3]]
        )
        splatting_module = splatting.Splatting("softmax")
        output = splatting_module(frame, flow, importance_metric).sum()

        def test_fn():
            output.backward(retain_graph=True)

    else:
        raise NotImplementedError(f"method {method}")
    ex_time = (
        timeit.timeit(
            test_fn,
            number=repetitions,
        )
        / repetitions
    )
    print(
        f"backward \tbatch_size={batch_size}\tspatial_size={spatial_size}\t"
        + f"flow_init={flow_init}\tex_time={ex_time}"
    )
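A hedged sketch of how this helper might be driven; the parameter grid below is illustrative and not the project's actual benchmark configuration:

if __name__ == "__main__":
    # Sweep a few illustrative settings of the backward benchmark.
    for method in ["splatting_cpu", "splatting_function_summation"]:
        for batch_size in [1, 4]:
            for spatial_size in [64, 256]:
                for flow_init in ["zeros", "ones"]:
                    run_test_backward(method, batch_size, spatial_size, flow_init)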
Code example #7
File: benchmark.py  Project: ywu40/splatting
    def test_fn():
        # Closure over frame and flow defined in the enclosing benchmark function.
        splatting.splatting_function("summation", frame, flow)