Exemplo n.º 1
0
def lc_consistence_loss(axis, neighbor, idx):
    """Angle-thresholded agreement term between neighbouring local frames.

    For each point, the frames of its neighbours are gathered with
    ``index_points``. For frame columns 0 and 1 separately, all pairwise dot
    products between the neighbours' column vectors are formed. The dot
    products that fall below ``cos(15 * (idx + 1) degrees)`` are averaged and
    added to the result. For unit-length axes, these are the pairs whose angle
    exceeds the threshold. Column 2 is not used. No distance weighting is
    applied.

    NOTE(review): the returned value is the mean of the *selected* dot
    products, so a larger value means the misaligned pairs are less
    misaligned. Confirm the sign with which the caller adds this to the loss.

    Args:
        axis: per-point 3x3 frames, unpacked here as ``(B, N, 3, 3)``.
        neighbor: neighbour indices passed to ``index_points`` (presumably
            ``(B, N, K)`` -- TODO confirm against ``index_points``).
        idx: layer index. The angle threshold is ``15 * (idx + 1)`` degrees
            (15 deg for layer 0, 30 deg for layer 1, ...).

    Returns:
        0-dim tensor. It sums the per-column means. A column whose dot
        products are all at or above the threshold contributes nothing.
    """
    import math

    B, N, _, _ = axis.shape
    # Gather every point's neighbour frames and regroup them as (B*N, K, 3, 3).
    grouped_axis = index_points(axis.reshape(B, N, 3 * 3), neighbor).reshape(B * N, -1, 3, 3)

    # The threshold angle grows with the layer index: 15 deg, 30 deg, 45 deg, ...
    threshold = math.cos(math.radians(15.0 * (idx + 1)))

    res = torch.zeros((), device=grouped_axis.device)
    for col in (0, 1):
        vectors = grouped_axis[:, :, :, col]  # (B*N, K, 3)
        dots = torch.matmul(vectors, vectors.transpose(1, 2))  # (B*N, K, K) pairwise dot products
        selected = dots[dots < threshold]
        # A mean over an empty selection would be NaN, so empty columns are skipped.
        if selected.numel() > 0:
            res += selected.mean()
    return res
Exemplo n.º 2
0
    def forward(self, xyz, neighbors, data_idxes, local_axises, cls_label):
        """Build per-layer local coordinates and predict log-probabilities.

        Args:
            xyz: ``(B, N, C)`` input points. Channels 0:3 are used as
                coordinates.
            neighbors: per-layer neighbour indices, stacked along dim 1. Layer
                ``i`` occupies ``self.point_num[i]`` rows.
            data_idxes: per-layer subsampling indices, stacked the same way.
            local_axises: per-layer 3x3 frames, stacked the same way.
            cls_label: object-category label, reshaped to ``(B, 16, 1)``
                (presumably a 16-way one-hot -- TODO confirm).

        Returns:
            ``(log_probs, lc_std, lc_consistence)``. ``log_probs`` is
            ``log_softmax`` over the last dim of ``fc3``'s output.
            ``lc_std`` is returned as a zero tensor. ``lc_consistence`` sums
            ``lc_consistence_loss`` over the layers.
        """
        B, N, _ = xyz.shape

        neighbors_layer = []
        local_coordinates_layer = []
        data_idxes_layer = []
        parameters_layer = [0] * 5  # passed through unchanged to self.localnet
        lc_std = torch.tensor(0.0, device=xyz.device)  # not updated in this method
        lc_consistence = torch.tensor(0.0, device=xyz.device)
        cid = 0
        current_xyz = xyz
        for idx, point_num in enumerate(self.point_num):
            rows = slice(cid, cid + point_num)
            # nn.Module.train is a method, so `if self.train:` was always true
            # and the eval branch below was unreachable. The mode flag is
            # self.training.
            if self.training:
                if self.random_nb:
                    # Keep column 0 plus KNN_NUM random columns drawn from the first 32.
                    # NOTE(review): choice(32, ...) can draw 0 again, duplicating column 0.
                    random_index = torch.from_numpy(
                        np.concatenate([np.array([0]), np.random.choice(32, KNN_NUM, replace=False)], 0)
                    ).to(xyz.device).long()
                    NBs = neighbors[:, rows, random_index].squeeze()
                else:
                    NBs = neighbors[:, rows, 0:KNN_NUM + 1].squeeze()
                if self.random_sp:
                    # Sample from the current point set: the full input (N points)
                    # for the first layer, otherwise the previous layer's output.
                    # N replaces a hard-coded 2048.
                    population = N if idx == 0 else self.point_num[idx - 1]
                    # NOTE(review): NBs still come from the precomputed table, which
                    # is not aligned with these randomly sampled DIs.
                    DIs = random_sample(population, point_num, B)
                else:
                    DIs = data_idxes[:, rows].squeeze()
            else:
                NBs = neighbors[:, rows, 0:KNN_NUM + 1].squeeze()
                DIs = data_idxes[:, rows].squeeze()

            axis = local_axises[:, rows, :, :]
            current_xyz = index_points(current_xyz, DIs)
            # Neighbour offsets relative to each (subsampled) centre, projected onto its frame.
            grouped_xyz = index_points(current_xyz[:, :, 0:3], NBs)
            grouped_xyz = grouped_xyz - current_xyz[:, :, 0:3].unsqueeze(2)  # BNKC
            lc = torch.matmul(grouped_xyz, axis)  # BNKC
            local_coordinates_layer.append(lc)

            lc_consistence += lc_consistence_loss(axis, NBs, idx)

            neighbors_layer.append(NBs)
            data_idxes_layer.append(DIs)
            cid += point_num

        l2_xyz_local, l2_points_local = self.localnet(
            xyz, local_coordinates_layer, neighbors_layer, data_idxes_layer, parameters_layer)
        # Broadcast the category label to every point and append it as features.
        cls_label_one_hot = cls_label.view(B, 16, 1).repeat(1, 1, N).permute(0, 2, 1)  # BNC
        l2_points_local = torch.cat([l2_points_local, cls_label_one_hot], 2)
        x = self.drop1(F.relu(self.fc1(l2_points_local)))
        x = self.drop2(F.relu(self.fc2(x)))
        x = self.fc3(x)
        x = F.log_softmax(x, -1)
        return x, lc_std, lc_consistence
Exemplo n.º 3
0
    def forward(self, xyz, neighbors, local_axis):
        """Predict a per-point local reference frame.

        Args:
            xyz: ``(B, N, C)`` global coordinates. When ``self.normal`` is set,
                channels ``3:`` are split off and treated as normals.
            neighbors: neighbour indices for ``index_points`` (presumably
                ``(B, N, K)`` -- TODO confirm).
            local_axis: ``(B, N, 3, 3)`` initial frames. Neighbour features
                are right-multiplied by it.

        Returns:
            ``[x_axis, y_axis, z_axis]``, each ``(B, N, 3)``.

            * With ``self.normal``: the network predicts one angle per point.
              ``x_axis`` rotates ``local_axis[..., 0]`` towards
              ``local_axis[..., 1]`` by that angle (a true rotation if those
              columns are orthonormal). ``y_axis = normal x x_axis``, and the
              normal itself is returned as z.
            * Otherwise: two predicted 3-vectors are orthonormalised
              (Gram-Schmidt) into ``z_axis`` and ``x_axis``, and
              ``y_axis = z_axis x x_axis``.
        """
        import math

        normal = None
        if self.normal:
            normal = xyz[:, :, 3:]
            xyz = xyz[:, :, 0:3]

        grouped_xyz = index_points(xyz, neighbors)  # (B, N, K, C)
        B, N, K, _ = grouped_xyz.shape
        local_xyz = grouped_xyz if self.global_input else grouped_xyz - xyz.unsqueeze(2)
        local_xyz = torch.matmul(local_xyz, local_axis)

        features = [local_xyz, grouped_xyz]
        # Previously index_points(normal, ...) ran unconditionally, which raised
        # NameError whenever self.normal was False.
        # NOTE(review): sa3 must be built with the matching channel count in that mode.
        if normal is not None:
            features.append(torch.matmul(index_points(normal, neighbors), local_axis))
        # The features are laid out (B, N, K, C), so each point becomes one
        # (K, C) row. The former permute(0, 3, 2, 1) before this reshape mixed
        # values from different points into every row.
        input_xyz = torch.cat(features, 3).reshape(B * N, K, -1)
        _, l3_points = self.sa3(input_xyz)
        x = l3_points.reshape(B * N, -1)
        x = self.drop1(F.relu(self.fc1(x)))
        x = self.drop2(F.relu(self.fc2(x)))
        x = torch.tanh(self.fc3(x))  # in [-1, 1]; F.tanh is deprecated

        if self.normal:
            # One angle per point, scaled to [-2*pi, 2*pi].
            theta = x.reshape(B, N) * 2 * math.pi
            cos_t = torch.cos(theta).unsqueeze(2)
            sin_t = torch.sin(theta).unsqueeze(2)
            x_axis = cos_t * local_axis[:, :, :, 0] + sin_t * local_axis[:, :, :, 1]
            # Explicit dim: the implicit default picks the first size-3 dim,
            # which is wrong when B or N happens to be 3.
            y_axis = torch.cross(normal, x_axis, dim=2)
            return [x_axis, y_axis, normal]

        # Gram-Schmidt on the two predicted vectors (fc3 presumably outputs 6 values).
        alpha1 = x[:, 0:3]  # (B*N, 3)
        alpha2 = x[:, 3:]
        alpha1_norm = alpha1.norm(dim=1) + 1e-9  # (B*N,)
        # Projection coefficient of alpha2 on alpha1. A row-wise sum avoids the
        # squeeze() that collapsed to 0-dim when B*N == 1.
        k = (alpha1 * alpha2).sum(dim=1) / alpha1_norm ** 2
        beta2 = alpha2 - k.unsqueeze(1) * alpha1
        x_axis = beta2 / (beta2.norm(dim=1, keepdim=True) + 1e-9)
        z_axis = alpha1 / alpha1_norm.unsqueeze(1)
        y_axis = torch.cross(z_axis, x_axis, dim=1)
        return [x_axis.reshape(B, N, -1), y_axis.reshape(B, N, -1), z_axis.reshape(B, N, -1)]
Exemplo n.º 4
0
    def forward(self, xyz, neighbors, data_idxes, local_axises):
        """Classify a point cloud from hierarchical local coordinates.

        At each layer the point set is subsampled. The indices come from
        ``data_idxes`` or, when ``self.random_sp`` is set, from
        ``random_sample``. Each layer's frames are gathered from the previous
        layer's frames with the same indices. Neighbour offsets are projected
        onto those frames and handed to ``self.localnet``, then to
        ``self.sa4`` and the FC head.

        Returns:
            ``(log_probs, lc_std, lc_consistence)``. The two auxiliary terms
            are zero tensors on ``xyz``'s device and are not updated here.
        """
        B, _, _ = xyz.shape

        neighbors_layer, local_coordinates_layer, data_idxes_layer = [], [], []
        parameters_layer = [0] * 5
        lc_std = torch.tensor(0.0, device=xyz.device)
        lc_consistence = torch.tensor(0.0, device=xyz.device)

        current_xyz = xyz
        # Frames of the first layer's rows. Each later layer gathers from the previous one.
        prev_axis = local_axises[:, 0:self.point_num[0], :, :]
        start = 0
        for idx, point_num in enumerate(self.point_num):
            stop = start + point_num
            if not self.random_sp:
                DIs = data_idxes[:, start:stop].squeeze()
                NBs = neighbors[:, start:stop, 0:KNN_NUM + 1].squeeze()
            elif idx == 0:
                # NOTE(review): precomputed neighbours are not aligned with these random DIs.
                DIs = random_sample(2048, point_num, B)
                NBs = neighbors[:, start:stop, 0:KNN_NUM + 1].squeeze()
            else:
                DIs = random_sample(self.point_num[idx - 1], point_num, B)
                # NOTE(review): kNN runs on all channels of current_xyz, not just 0:3.
                all_knn = knn_point(KNN_NUM + 1, current_xyz, current_xyz).squeeze()
                NBs = index_points(all_knn, DIs)

            layer_axis = index_points(prev_axis.reshape(B, -1, 9), DIs).reshape(B, -1, 3, 3)
            prev_axis = layer_axis

            # NBs index the point set *before* this layer's subsampling. The
            # centres are taken after it.
            neighbour_xyz = index_points(current_xyz[:, :, 0:3], NBs)
            current_xyz = index_points(current_xyz, DIs)
            offsets = neighbour_xyz - current_xyz[:, :, 0:3].unsqueeze(2)  # BNKC
            local_coordinates_layer.append(torch.matmul(offsets, layer_axis))

            neighbors_layer.append(NBs)
            data_idxes_layer.append(DIs)
            start = stop

        l2_xyz_local, l2_points_local = self.localnet(
            xyz, local_coordinates_layer, neighbors_layer, data_idxes_layer, parameters_layer)
        _, l4_points = self.sa4(l2_xyz_local, l2_points_local)
        feat = l4_points.view(B, self.scale * 512)
        feat = self.drop1(F.relu(self.bn1(self.fc1(feat))))
        # NOTE(review): drop1 is reused here (there is no drop2 call); kept as-is.
        feat = self.drop1(F.relu(self.bn2(self.fc2(feat))))
        return F.log_softmax(self.fc3(feat), -1), lc_std, lc_consistence