Ejemplo n.º 1
0
    def forward(self, x, y):
        """Run the two-level point network and return the regression head output.

        Args:
            x: level-1 grouped point features; per the original shape note,
                B*INPUT_FEATURE_NUM*sample_num_level1*knn_K.
            y: level-1 centroids, B*3*sample_num_level1*1 per the same note.

        Returns:
            The output of ``self.netR_FC`` (annotated as B*num_outputs).
        """
        level1_feat = self.netR_1(x)  # B*128*sample_num_level1*1
        # Prepend the centroids on the channel axis and drop the trailing
        # singleton axis: B*(3+128)*sample_num_level1.
        level1_points = torch.cat((y, level1_feat), 1).squeeze(-1)

        grouped, grouped_centers = group_points_2(
            level1_points,
            self.sample_num_level1,
            self.sample_num_level2,
            self.knn_K,
            self.ball_radius2,
        )
        # grouped: B*131*sample_num_level2*knn_K
        # grouped_centers: B*3*sample_num_level2*1

        # B*(3+256)*sample_num_level2*1
        level2_feat = torch.cat((grouped_centers, self.netR_2(grouped)), 1)
        # Global descriptor flattened to B*nstates_plus_3[2] (B*1024 per note).
        global_feat = self.netR_3(level2_feat).view(-1, nstates_plus_3[2])
        return self.netR_FC(global_feat)
Ejemplo n.º 2
0
    def forward(self, x, y):
        """Run the two-level point network with a jittered level-2 radius.

        Args:
            x: level-1 grouped point features; per the original shape note,
                B*INPUT_FEATURE_NUM*sample_num_level1*knn_K.
            y: level-1 centroids; only the first three channels are used
                (B*3*sample_num_level1*1 after slicing).

        Returns:
            The output of ``self.netR_FC`` (annotated as B*num_label).
        """
        x = self.netR_1(x)
        # B*128*sample_num_level1*1
        y = y[:, :3, :, :]
        x = torch.cat((y, x), 1).squeeze(-1)
        # B*(3+128)*sample_num_level1

        # Per-call jittered radius. It is deliberately NOT written back to
        # self.ball_radius2: doing so would compound the -0.1 offset and the
        # noise on every forward pass, shrinking the radius without bound.
        # NOTE(review): the jitter is applied in eval mode as well.
        ball_radius2 = self.ball_radius2 - 0.1 + torch.randn(1) / 120.0
        inputs_level2, inputs_level2_center = group_points_2(
            x, self.sample_num_level1, self.sample_num_level2, self.knn_K,
            ball_radius2)
        # B*131*sample_num_level2*knn_K, B*3*sample_num_level2*1

        x = self.netR_2(inputs_level2)
        # B*256*sample_num_level2*1
        x = torch.cat((inputs_level2_center, x), 1)
        # B*259*sample_num_level2*1

        x = self.netR_3(x)
        # B*1024*1*1
        x = x.view(-1, nstates_plus_3[2])
        # B*1024
        x = self.netR_FC(x)
        # B*num_label
        return x
Ejemplo n.º 3
0
    def forward(self, xt, xs1, xs2, xs3, yt, ys1, ys2, ys3):
        """Fuse one motion stream and three appearance streams.

        Args:
            xt: motion-stream level-1 grouped features for ``self.netDI_1``
                (per the original note, B*INPUT_FEATURE_NUM*sample_num_level1*knn_K).
            xs1, xs2, xs3: appearance-stream grouped features. They are stacked
                and run through one shared ``netR_C*`` pipeline, so they must
                share a shape.
            yt: motion-stream level-1 centroids (B*3*sample_num_level1*1 per note).
            ys1, ys2, ys3: appearance-stream centroids matching ``xs1..xs3``.

        Returns:
            ``self.netR_FC`` applied to the concatenated motion/appearance
            features when ``self.pooling == 'concatenation'``; otherwise the
            concatenated features themselves.
        """
        # ---- motion stream ----
        xt = self.netDI_1(xt)
        # B*128*sample_num_level1*1
        xt = torch.cat((yt, xt), 1).squeeze(-1)
        # B*(3+128)*sample_num_level1

        # Per-call jittered radius, shared by both streams below. Kept local:
        # assigning it back to self.ball_radius2 would turn the jitter into a
        # random walk that accumulates across forward passes.
        # NOTE(review): the jitter is applied in eval mode as well.
        ball_radius2 = self.ball_radius2 + torch.randn(1) / 120.0
        inputs_level2, inputs_level2_center = group_points_2(
            xt, self.sample_num_level1, self.sample_num_level2, self.knn_K,
            ball_radius2)
        # B*131*sample_num_level2*knn_K, B*3*sample_num_level2*1
        xt = self.netDI_2(inputs_level2)
        # B*256*sample_num_level2*1
        xt = torch.cat((inputs_level2_center, xt), 1)
        # B*259*sample_num_level2*1
        xt = self.netDI_3(xt).squeeze(-1).squeeze(-1)

        # ---- appearance streams: one network shared by all three ----
        # Stack the streams on a new axis, then fold that axis into the batch.
        xs = torch.stack((xs1, xs2, xs3), 1)
        B, c, d, N, k = xs.shape
        xs = xs.view(-1, d, N, k)
        xs = self.netR_C1(xs)
        ys = torch.stack((ys1, ys2, ys3), 1)
        # Use the centroids' own channel count: `d` is the feature count of
        # xs and need not equal the number of coordinate channels in ys.
        ys = ys.view(-1, ys.size(2), N, 1)
        xs = torch.cat((ys, xs), 1).squeeze(-1)
        inputs_level2_r, inputs_level2_center_r = group_points_2(
            xs, self.sample_num_level1, self.sample_num_level2, self.knn_K,
            ball_radius2)
        xs = self.netR_C2(inputs_level2_r)
        xs = torch.cat((inputs_level2_center_r, xs), 1)
        xs = self.netR_C3(xs).squeeze(-1).squeeze(-1)
        # Move the stream axis back out of the batch: B*(c*features).
        xs = xs.view(B, -1)

        x = torch.cat((xt, xs), -1)
        if self.pooling == 'concatenation':
            x = self.netR_FC(x)
            # B*num_label
        return x
Ejemplo n.º 4
0
    def forward(self, xt, xs1, xs2, xs3, yt, ys1, ys2, ys3):
        """Fuse the 3DV motion stream with three appearance streams.

        Args:
            xt: motion-stream grouped features for ``self.net3DV_1``.
            xs1, xs2, xs3: appearance-stream grouped features. They are stacked
                and run through one shared ``netR_C*`` pipeline, so they must
                share a shape.
            yt: motion-stream centroids (the original note gives
                B*(4+128)*sample_num_level1 after concatenation, i.e. 4 channels).
            ys1, ys2, ys3: appearance-stream centroids matching ``xs1..xs3``.

        Returns:
            ``self.netR_FC`` applied to the concatenated motion/appearance
            features when ``self.pooling == 'concatenation'``; otherwise the
            concatenated features themselves.
        """
        # ---- motion stream ----
        xt = self.net3DV_1(xt)
        xt = torch.cat((yt, xt), 1).squeeze(-1)
        # B*(4+128)*sample_num_level1

        # Per-call jittered radius, shared by both streams below. Kept local:
        # assigning it back to self.ball_radius2 would turn the jitter into a
        # random walk that accumulates across forward passes.
        # NOTE(review): the jitter is applied in eval mode as well.
        ball_radius2 = self.ball_radius2 + torch.randn(1) / 120.0
        inputs_level2, inputs_level2_center = group_points_2_3DV(
            xt, self.sample_num_level1, self.sample_num_level2, self.knn_K,
            ball_radius2)
        xt = self.net3DV_2(inputs_level2)
        xt = torch.cat((inputs_level2_center, xt), 1)
        xt = self.net3DV_3(xt).squeeze(-1).squeeze(-1)

        # ---- appearance streams: one PointNet++ shared by all three ----
        # Stack the streams on a new axis, then fold that axis into the batch.
        xs = torch.stack((xs1, xs2, xs3), 1)
        B, c, d, N, k = xs.shape
        xs = xs.view(-1, d, N, k)
        xs = self.netR_C1(xs)

        ys = torch.stack((ys1, ys2, ys3), 1)
        # Use the centroids' own channel count: `d` is the feature count of
        # xs and need not equal the number of coordinate channels in ys.
        ys = ys.view(-1, ys.size(2), N, 1)
        xs = torch.cat((ys, xs), 1).squeeze(-1)
        # B*(3+128)*sample_num_level1
        inputs_level2_r, inputs_level2_center_r = group_points_2(
            xs, self.sample_num_level1, self.sample_num_level2, self.knn_K,
            ball_radius2)
        xs = self.netR_C2(inputs_level2_r)
        xs = torch.cat((inputs_level2_center_r, xs), 1)
        xs = self.netR_C3(xs).squeeze(-1).squeeze(-1)
        # Move the stream axis back out of the batch: B*(c*features).
        xs = xs.view(B, -1)

        x = torch.cat((xt, xs), -1)
        if self.pooling == 'concatenation':
            x = self.netR_FC(x)
            # B*num_label
        return x
Ejemplo n.º 5
0
    def forward(self, x, y):
        """Encode the level-1 points and decode 21 points in three folding stages.

        Args:
            x: level-1 grouped point features; per the original shape note,
                B*INPUT_FEATURE_NUM*sample_num_level1*knn_K.
            y: level-1 centroids, B*3*sample_num_level1*1 per the same note.

        Returns:
            ``(fold1, fold2, fold3)``, each reshaped to ``(-1, 63)``
            (presumably 21 joints x 3 coordinates per row). ``fold2`` and
            ``fold3`` each add a residual to the previous stage's estimate.
        """
        x = self.netR_1(x)
        # B*128*sample_num_level1*1
        level1 = torch.cat((y, x), 1).squeeze(-1)
        # B*(3+128)*sample_num_level1

        inputs_level2, inputs_level2_center = group_points_2(
            level1, self.sample_num_level1, self.sample_num_level2, self.knn_K,
            self.ball_radius2)
        # B*131*sample_num_level2*knn_K, B*3*sample_num_level2*1
        x = self.netR_2(inputs_level2)
        # B*256*sample_num_level2*1
        x = torch.cat((inputs_level2_center, x), 1)
        # B*259*sample_num_level2*1
        del inputs_level2, inputs_level2_center

        code = self.netR_3(x)
        # B*512*1*1
        batch = x.size(0)

        # Fixed 2*21 template, created on the code's device instead of being
        # hard-coded to .cuda(), so the model also runs on CPU or any GPU.
        skeleton = torch.tensor(
            [[0, -0.15, -0.15, -0.15, -0.15, 0, 0, 0, 0, 0.175, 0.175, 0.175,
              0.175, 0.3, 0.3, 0.3, 0.3, -0.12, -0.12, -0.12, -0.12],
             [0, 0.26, 0.4, 0.5, 0.6, 0.33, 0.49, 0.59, 0.69, 0.3, 0.45,
              0.55, 0.65, 0.175, 0.275, 0.35, 0.425, 0.07, 0.2, 0.3, 0.4]],
            dtype=torch.float32, device=code.device)
        skeleton = skeleton.unsqueeze(0).unsqueeze(-2).expand(batch, 2, 1, 21)
        code = code.expand(batch, encoder_3[2], 1, 21)
        x = torch.cat((skeleton, code), 1)
        # B*(2+512)*1*21

        # Stage 1: fold the template conditioned on the global code.
        fold1_1 = self.netFolding1_1(x)
        fold1 = self.netFolding1_2(fold1_1)
        # B*3*1*21

        # Stage 2: residual refinement from the stage-1 estimate.
        x = self._folding_input(fold1, fold1_1, level1, folding_1[2])
        fold2_1 = self.netFolding2_1(x)
        fold2 = self.netFolding2_2(fold2_1) + fold1

        # Stage 3: residual refinement from the stage-2 estimate.
        x = self._folding_input(fold2, fold2_1, level1, folding_2[2])
        fold3_1 = self.netFolding3_1(x)
        fold3 = self.netFolding3_2(fold3_1) + fold2

        # B*(21*3)
        fold3 = fold3.transpose(1, 3).contiguous().view(-1, 63)
        fold2 = fold2.transpose(1, 3).contiguous().view(-1, 63)
        fold1 = fold1.transpose(1, 3).contiguous().view(-1, 63)
        return fold1, fold2, fold3

    def _relative_joint_features(self, feat):
        """Gather ``feat`` along its last axis with ``relative_idx1``/``relative_idx2``.

        ``feat`` is expected as B*C*1*21 (assumption based on the folding-stage
        shape notes in ``forward``). Returns two tensors of that same shape.
        The module-level index arrays must hold an integer dtype usable by
        ``torch.gather`` (presumably int64 — confirm).
        """
        flat = feat.squeeze(2)
        gathered = []
        for rel_idx in (relative_idx1, relative_idx2):
            index = torch.from_numpy(rel_idx).to(feat.device).unsqueeze(
                0).unsqueeze(0).expand(feat.size(0), feat.size(1), 21)
            gathered.append(torch.gather(flat, 2, index).unsqueeze(-2))
        return gathered

    def _folding_input(self, fold, fold_feat, level1, width):
        """Assemble the B*channels*64*21 input of the next folding stage.

        Concatenates on the channel axis:
        - the current estimate ``fold`` (3 channels);
        - its features ``fold_feat``;
        - the two relative-index gathers of ``fold_feat`` (``width`` channels
          each for this and the previous item);
        - the ``final_group(..., 64, 0.16)`` output linking ``fold`` to the
          level-1 points.

        All parts except the last are broadcast to size 64 along dim 2.
        """
        rel_a, rel_b = self._relative_joint_features(fold_feat)
        link = final_group(
            fold.squeeze(2).transpose(1, 2),
            level1.squeeze(-1).transpose(1, 2), 64, 0.16).transpose(1, 3)
        batch = fold.size(0)
        return torch.cat(
            (fold.expand(batch, 3, 64, 21),
             fold_feat.expand(batch, width, 64, 21),
             rel_a.expand(batch, width, 64, 21),
             rel_b.expand(batch, width, 64, 21),
             link), 1)