Example 1
import numpy as np
import torch
import kornia


def get_non_differentiable_rectangle_depth_estimation(reference_pose_torch,
                                                      measurement_pose_torch,
                                                      previous_depth_torch,
                                                      full_K_torch,
                                                      half_K_torch,
                                                      original_width,
                                                      original_height):
    batch_size, _, _ = reference_pose_torch.shape
    half_width = int(original_width / 2)
    half_height = int(original_height / 2)

    # Relative transform that maps points from the measurement frame into the
    # reference frame.
    trans = torch.bmm(torch.inverse(reference_pose_torch), measurement_pose_torch)

    # Back-project the previous depth map to a 3D point cloud and move it into
    # the reference camera's coordinate frame.
    points_3d_src = kornia.depth_to_3d(previous_depth_torch, full_K_torch, normalize_points=False)
    points_3d_src = points_3d_src.permute(0, 2, 3, 1)
    points_3d_dst = kornia.transform_points(trans[:, None], points_3d_src)

    points_3d_dst = points_3d_dst.view(batch_size, -1, 3)

    # Clamp negative depths to zero and sort points by depth (descending) so
    # that duplicate pixels can be resolved deterministically below.
    z_values = points_3d_dst[:, :, -1]
    z_values = torch.relu(z_values)
    sorting_indices = torch.argsort(z_values, descending=True)
    z_values = torch.gather(z_values, dim=1, index=sorting_indices)

    # Apply the same ordering to the 3D points themselves.
    sorting_indices_for_points = torch.stack([sorting_indices] * 3, dim=-1)
    points_3d_dst = torch.gather(points_3d_dst, dim=1, index=sorting_indices_for_points)

    # Project into the half-resolution image and keep only in-bounds pixels.
    projections = torch.round(kornia.project_points(points_3d_dst, half_K_torch.unsqueeze(1))).long()
    is_valid_below = (projections[:, :, 0] >= 0) & (projections[:, :, 1] >= 0)
    is_valid_above = (projections[:, :, 0] < half_width) & (projections[:, :, 1] < half_height)
    is_valid = is_valid_below & is_valid_above

    depth_hypothesis = torch.zeros(size=(batch_size, 1, half_height, half_width)).cuda()
    for projection_index in range(0, batch_size):
        valid_points_zs = z_values[projection_index][is_valid[projection_index]]
        valid_projections = projections[projection_index][is_valid[projection_index]]
        i_s = valid_projections[:, 1]
        j_s = valid_projections[:, 0]
        # Deduplicate pixel coordinates: np.unique keeps the first occurrence,
        # i.e. one depth value per (i, j) according to the sort order above.
        ij_combined = i_s * half_width + j_s
        _, ij_combined_unique_indices = np.unique(ij_combined.cpu().numpy(), return_index=True)
        ij_combined_unique_indices = torch.from_numpy(ij_combined_unique_indices).long().cuda()
        i_s = i_s[ij_combined_unique_indices]
        j_s = j_s[ij_combined_unique_indices]
        valid_points_zs = valid_points_zs[ij_combined_unique_indices]
        # Scatter the surviving depths into the hypothesis map at (row, col).
        torch.index_put_(depth_hypothesis[projection_index, 0], (i_s, j_s), valid_points_zs)
    return depth_hypothesis
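
For reference (not part of the example above), a minimal sketch of what the final index_put_ call does: it scatters a 1-D tensor of values into a 2-D tensor at positions given by a tuple of (row, column) index tensors. The tensor names here are purely illustrative.

import torch

grid = torch.zeros(4, 5)
rows = torch.tensor([0, 2, 3])
cols = torch.tensor([1, 4, 0])
vals = torch.tensor([1.5, 2.0, 3.25])

# grid[rows[k], cols[k]] = vals[k] for every k (accumulate defaults to False).
torch.index_put_(grid, (rows, cols), vals)
print(grid)

Deduplicating the (i, j) pairs beforehand, as the function above does, avoids handing index_put_ duplicate indices, for which the surviving value is not guaranteed when accumulate=False.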
Example 2
    def test_invalid_device(self, device):
        # The value tensor lives on the CPU while the target tensor lives on
        # `device`; index_put_ is expected to raise when the devices differ.
        idx = torch.tensor([0, 1])
        b = torch.zeros(5, device=device)
        c = torch.tensor([1., 2.], device="cpu")

        for accumulate in [True, False]:
            self.assertRaisesRegex(
                RuntimeError, 'expected device',
                lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate))
Example 3
    def test_hacked_twin(self):
        # The "hacked_twin" overloads take a plain list of index tensors (no
        # None entries) and should behave exactly like torch.index_put /
        # torch.index_put_ with the same arguments.
        def gen_data():
            # freeze_rng_state (from torch.testing._internal.common_utils)
            # makes both calls produce identical random tensors.
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        input, index, value = gen_data()
        input1, index1, value1 = gen_data()

        # Out-of-place variants.
        out1 = torch.ops.aten.index_put.hacked_twin(input, [index], value,
                                                    accumulate=False)
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        # In-place variants.
        torch.ops.aten.index_put_.hacked_twin(input, [index], value,
                                              accumulate=False)
        torch.index_put_(input1, [index1], value1, accumulate=False)
        self.assertEqual(input, input1)
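
As a side note (this sketch is not from the PyTorch test suite), the accumulate flag used throughout these examples changes how duplicate indices are handled: with accumulate=True the corresponding values are summed, while with accumulate=False only one of them survives and which one is not guaranteed.

import torch

t = torch.zeros(5)
idx = torch.tensor([1, 1, 3])
vals = torch.tensor([10., 20., 5.])

# accumulate=True: duplicates are summed -> result[1] == 30., result[3] == 5.
summed = torch.index_put(t, (idx,), vals, accumulate=True)

# accumulate=False: result[1] is either 10. or 20. (unspecified), result[3] == 5.
overwritten = torch.index_put(t, (idx,), vals, accumulate=False)

print(summed, overwritten)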