Ejemplo n.º 1
0
    def _compute_acceleration(self, vert_t, vert_tp1, vert_0, neigh, acc, step_size):
        """Compute per-vertex acceleration terms from edge-vector residuals.

        Per-vertex rotations are fitted with ``arap`` between the rest-pose
        edges (``vert_0``) and the edges at ``t+1`` (``vert_tp1``).  Each
        rest-pose edge is rotated by the mean of its two endpoint rotations
        and subtracted from the current edge (``vert_t``).  That residual is
        mapped back through the composed rotations ``R @ acc^T`` and scattered
        onto the edge endpoints with opposite signs.

        Args:
            vert_t: vertex positions at time t, one row per vertex
                (presumably ``[n_vert, 3]`` -- confirm).
            vert_tp1: vertex positions at time t+1, same layout as ``vert_t``.
            vert_0: rest-pose vertex positions, same layout as ``vert_t``.
            neigh: integer edge list; columns 0 and 1 index the two endpoint
                vertices of each edge.
            acc: per-vertex batched matrices; transposed over dims 1/2 and
                right-multiplied onto the fitted rotations (presumably
                ``[n_vert, 3, 3]`` -- confirm against callers).
            step_size: not read by this method; kept for interface
                compatibility.

        Returns:
            ``(acc_t, vert_diff_0_rot, R_tp1)``:

            * ``acc_t``: per-vertex vectors, shaped like ``vert_t``.
            * ``vert_diff_0_rot``: the rotated rest-pose edge vectors, one row
              per edge.
            * ``R_tp1``: the composed matrices ``arap(...) @ acc^T``, not the
              raw ``arap`` output.
        """
        n_vert = vert_t.shape[0]
        first, second = neigh[:, 0], neigh[:, 1]

        vert_diff_t = vert_t[first, :] - vert_t[second, :]
        vert_diff_0 = vert_0[first, :] - vert_0[second, :]
        vert_diff_tp1 = vert_tp1[first, :] - vert_tp1[second, :]

        R_tp1 = arap(vert_diff_tp1, vert_diff_0, neigh, n_vert)

        # Each edge uses the average rotation of its two endpoints.
        R_neigh_tp1 = 0.5 * (torch.index_select(R_tp1, 0, first) + torch.index_select(R_tp1, 0, second))

        # squeeze(2), not squeeze(): a bare squeeze() would also drop the edge
        # axis when there is exactly one edge, breaking the shapes below.
        vert_diff_0_rot = torch.bmm(R_neigh_tp1, vert_diff_0.unsqueeze(2)).squeeze(2)
        acc_t_neigh = vert_diff_t - vert_diff_0_rot

        R_tp1 = torch.bmm(R_tp1, acc.transpose(1, 2))

        R_neigh_tp1 = 0.5 * (torch.index_select(R_tp1, 0, first) + torch.index_select(R_tp1, 0, second))
        acc_t_neigh = torch.bmm(R_neigh_tp1.transpose(1, 2), acc_t_neigh.unsqueeze(2)).squeeze(2)

        # Accumulate in vert_t's dtype/device: index_add requires the source
        # (derived from vert_t) to match, and this avoids a global `device`.
        acc_t = torch.zeros_like(vert_t)
        acc_t = torch.index_add(acc_t, 0, first, -acc_t_neigh)
        acc_t = torch.index_add(acc_t, 0, second, acc_t_neigh)

        # squeeze(2) keeps the vertex axis even when n_vert == 1.
        acc_t = torch.bmm(R_tp1, acc_t.unsqueeze(2)).squeeze(2)

        return acc_t, vert_diff_0_rot, R_tp1
Ejemplo n.º 2
0
def arap(vert_diff_t, vert_diff_0, neigh, n_vert):
    """Fit one rotation-like matrix per vertex from paired edge vectors.

    Both edge sets are normalised.  For every edge, ``hat_op(e0) @ et`` is
    computed (assumed to be the cross product ``e0 x et`` -- confirm
    ``hat_op`` builds the 3x3 skew matrix).  It is weighted by the product of
    the two edge lengths, accumulated onto both endpoints and averaged.  The
    averaged vector ``v`` is turned into a matrix with the closed form
    ``R = I + [v]_x + [v]_x^2 / (1 + cos)``, where ``sin = |v|`` and
    ``cos = sqrt(|1 - sin^2|)``.

    Args:
        vert_diff_t: deformed edge vectors, one row per edge (presumably
            ``[E, 3]``).
        vert_diff_0: rest-pose edge vectors, same layout as ``vert_diff_t``.
        neigh: integer edge list; columns 0 and 1 index the endpoints.
        n_vert: number of vertices (rows of the per-vertex accumulators).

    Returns:
        Per-vertex matrices stacked along dim 0.  Vertices with no incident
        edge get a zero axis vector instead of NaN.
    """
    norm_t = vert_diff_t.norm(dim=1, keepdim=True) + 1e-6
    norm_0 = vert_diff_0.norm(dim=1, keepdim=True) + 1e-6

    # Longer edges get a larger say in the per-vertex average.
    weight = norm_0 * norm_t

    vert_diff_0 = vert_diff_0 / norm_0
    vert_diff_t = vert_diff_t / norm_t

    # squeeze(2), not squeeze(): keep the edge axis even when E == 1.
    cross = torch.bmm(hat_op(vert_diff_0), vert_diff_t.unsqueeze(2)).squeeze(2)

    # compute per-point quantities
    cross_pp = my_zeros([n_vert, 3])
    weight_pp = my_zeros([n_vert, 1])

    weight_pp = torch.index_add(weight_pp, 0, neigh[:, 0], weight)
    weight_pp = torch.index_add(weight_pp, 0, neigh[:, 1], weight)

    cross = cross * weight
    cross_pp = torch.index_add(cross_pp, 0, neigh[:, 0], cross)
    cross_pp = torch.index_add(cross_pp, 0, neigh[:, 1], cross)
    # Every edge weight is > 0, so weight_pp is 0 only for vertices with no
    # incident edge (whose cross_pp is 0 as well).  Divide those by 1 so they
    # stay 0 rather than becoming 0 / 0 = NaN; all other rows are unchanged.
    weight_pp = torch.where(weight_pp > 0, weight_pp, torch.ones_like(weight_pp))
    cross_pp = cross_pp / weight_pp

    # compute rotation matrix
    hat_cross_pp = hat_op(cross_pp)

    # sin(alpha) is the length of the averaged axis vector, shaped [n, 1, 1]
    # so one scalar scales each whole matrix.  (hat_cross_pp.norm(dim=1)
    # would give per-column norms of the skew matrix, which are not |v|.)
    sin_alpha = cross_pp.norm(dim=1, keepdim=True).unsqueeze(2)
    cos_alpha = torch.sqrt(torch.abs(1 - sin_alpha**2))

    R = my_eye(3).unsqueeze(0) + hat_cross_pp + torch.bmm(
        hat_cross_pp, hat_cross_pp) / (1 + cos_alpha)

    return R
Ejemplo n.º 3
0
 def indexed_softmax(N, x, adj):
     """Softmax of edge scores ``x`` over the groups given by ``adj[:, 0]``.

     Each edge ``e`` is normalised against every edge ``e'`` with
     ``adj[e', 0] == adj[e, 0]``: ``exp(x[e]) / sum(exp(x[e']))``.

     Args:
         N: number of groups; every entry of ``adj[:, 0]`` must be in ``[0, N)``.
         x: edge scores with edges along dim 0, e.g. ``[E, 1]``.  ``[E, H]``
             (H independent score columns) and 1-D ``[E]`` also work, in any
             floating dtype.
         adj: integer tensor of shape ``[E, 2]``; only column 0 is read.

     Returns:
         Tensor with the same shape as ``x`` and the dtype/device of ``x.exp()``.

     Note:
         ``exp(x)`` is clamped to ``1e6`` to avoid inf / inf = NaN, so scores
         above ``ln(1e6)`` (about 13.8) saturate and are no longer distinguished.
     """
     # prevent infinity caused nan
     x_exp = x.exp().clamp(0, 1e6)
     group = adj[:, 0]
     # Group sums start at 1e-10 so a group whose exp() values all underflow to
     # 0 yields 0 instead of 0 / 0 = NaN.  The buffer takes x_exp's dtype and
     # device (index_add requires a matching dtype) and x's trailing dims; for
     # float32 [E, 1] input this is exactly the [N, 1] float32 buffer.
     denom = torch.index_add(x_exp.new_full((N,) + tuple(x_exp.shape[1:]), 1e-10),
                             dim=0, index=group, source=x_exp)
     return x_exp / denom[group]
Ejemplo n.º 4
0
        def run_test(batch_size):
            # x and values carry the batch dim; dim and index are shared.
            in_dims = (0, None, None, 0)
            n_batch = batch_size
            base = torch.randn(n_batch, 7, 11, 13)
            axis = 0
            idx = torch.tensor([0, 4, 2])
            src = torch.randn(n_batch, 3, 13)
            call_args = (base, axis, idx, src)

            self._assert_uses_vmap_fallback((torch.index_add, in_dims), call_args)

            batched = vmap(torch.index_add, in_dims)(*call_args)
            # Reference computed directly: per-sample dim `axis` is dim
            # `axis + 1` of the batched tensor.
            reference = torch.index_add(base, axis + 1, idx,
                                        src.view(n_batch, 3, 1, 13))
            self.assertEqual(batched, reference)
Ejemplo n.º 5
0
def compute_outer_normal(vert, triv, samples):
    """Return unit per-vertex normals (area-weighted) for the sampled vertices.

    Each face's unnormalised cross-product normal, whose length is twice the
    face area, is added to all three of its corners.  The per-vertex sums are
    then normalised.

    Args:
        vert: vertex positions, one row per vertex (``[V, 3]``).
        triv: integer triangle indices, one row per face (``[F, 3]``).
        samples: index (or mask) selecting which vertices' normals to return.

    Returns:
        ``normal[samples, :]``, in ``vert``'s dtype and device.  Vertices
        touching no face get a zero vector.
    """
    corner_0 = torch.index_select(vert, 0, triv[:, 0])
    edge_1 = torch.index_select(vert, 0, triv[:, 1]) - corner_0
    edge_2 = torch.index_select(vert, 0, triv[:, 2]) - corner_0

    # The 1e4 scaling presumably keeps face normals far above the 1e-5
    # epsilon used in the normalisation below.  dim=1 is explicit: without
    # it torch.cross picks the first size-3 dimension, which is the face
    # axis whenever F == 3.
    face_norm = torch.cross(1e4 * edge_1, 1e4 * edge_2, dim=1)

    # Accumulate in vert's dtype/device; index_add needs the source
    # (face_norm, derived from vert) to match.
    normal = torch.zeros_like(vert)
    for d in range(3):
        normal = torch.index_add(normal, 0, triv[:, d], face_norm)
    normal = normal / (1e-5 + normal.norm(dim=1, keepdim=True))

    return normal[samples, :]
Ejemplo n.º 6
0
        def run_test(batch_size):
            # x and values are batched along dim 0; dim and index are shared.
            n_batch = batch_size
            base = torch.randn(n_batch, 7, 11, 13)
            axis = 0
            idx = torch.tensor([0, 4, 2])
            src = torch.randn(n_batch, 3, 13)
            fallback_pattern = r'falling back to slow \(for loop and stack\) implementation'

            with warnings.catch_warnings(record=True) as caught:
                batched_index_add = vmap(torch.index_add, (0, None, None, 0))
                batched = batched_index_add(base, axis, idx, src)
                # Exactly two warnings are expected; the last must be the
                # slow-fallback notice.
                self.assertEqual(len(caught), 2)
                last_message = str(caught[-1].message)
                self.assertRegex(last_message, fallback_pattern)
                reference = torch.index_add(base, axis + 1, idx,
                                            src.view(n_batch, 3, 1, 13))
                self.assertEqual(batched, reference)
Ejemplo n.º 7
0
 def indexed_multiply_and_gather(x, Wh, adj):
     """Scale gathered rows of ``Wh`` by ``x`` and sum them per target row.

     For every edge ``e``, row ``adj[e, 1]`` of ``Wh`` is multiplied by
     ``x[e]`` and added into output row ``adj[e, 0]``.  Rows no edge targets
     stay zero.  The result has the shape, dtype and device of ``Wh``.
     """
     gather_idx = adj[:, 1]
     scatter_idx = adj[:, 0]
     weighted_rows = Wh[gather_idx] * x
     accumulator = torch.zeros_like(Wh)
     return accumulator.index_add(0, scatter_idx, weighted_rows)