Esempio n. 1
0
    def forward(self, x):
        """Compute per-point class scores for a batch of point clouds.

        Three edge-convolution stages extract local features (x1, x2, x3);
        conv6 followed by a max over the last axis gives a global feature that
        is repeated for every point and concatenated with the local features
        before the conv7-conv9 head.

        Args:
            x: 3-D input tensor, unpacked below as
                (batch_size, channels, num_points).

        Returns:
            Tensor of shape (batch_size, num_points, self.num_class).
        """
        batch_size, _, num_points = x.size()

        # Stage 1; unlike the later stages, the graph feature is built with
        # extra_dim=True (meaning defined by get_graph_feature).
        x = get_graph_feature(x, self.k, extra_dim=True)
        x = self.conv1(x)
        x = self.conv2(x)
        x1 = x.max(dim=-1, keepdim=False)[0]  # max-pool the last axis

        x = get_graph_feature(x1, self.k)
        x = self.conv3(x)
        x = self.conv4(x)
        x2 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x2, self.k)
        x = self.conv5(x)
        x3 = x.max(dim=-1, keepdim=False)[0]

        x = torch.cat((x1, x2, x3), dim=1)

        # Global feature: keepdim=True leaves a size-1 last axis that is then
        # repeated num_points times so it can be concatenated per point.
        x = self.conv6(x)
        x = x.max(dim=-1, keepdim=True)[0]

        x = x.repeat(1, 1, num_points)
        x = torch.cat((x, x1, x2, x3), dim=1)

        x = self.conv7(x)
        x = self.conv8(x)
        x = self.dp1(x)
        x = self.conv9(x)
        # x is channels-first here: the repeat above puts num_points on the last
        # axis, and conv7-conv9 are assumed to keep that layout (TODO confirm).
        # A bare .view(batch_size, num_points, num_class) would reinterpret the
        # memory and mix up points and classes, so swap the axes first; the
        # final .view still checks the expected shape.
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, num_points, self.num_class)
Esempio n. 2
0
    def forward(self, x):
        """Classify a batch of point clouds.

        Four edge-convolution stages run in sequence, each feeding its
        max-pooled output to the next. Their outputs are concatenated, fused by
        conv5, summarised with global max and average pooling, and passed
        through a three-layer MLP head.

        Args:
            x: input tensor whose first dimension is the batch size.

        Returns:
            Output of self.linear3 (presumably class logits of shape
            (batch, num_classes) -- TODO confirm).
        """
        n_batch = x.size(0)
        k = self.args.k

        stage_outputs = []
        h = x
        for edge_conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            # Neighbourhood features -> conv -> max-pool over the last axis.
            h = edge_conv(get_graph_feature(h, k=k)).max(dim=-1, keepdim=False)[0]
            stage_outputs.append(h)

        fused = self.conv5(torch.cat(stage_outputs, dim=1))

        # Global descriptor: max- and average-pooled features side by side.
        pooled = torch.cat(
            (
                F.adaptive_max_pool1d(fused, 1).view(n_batch, -1),
                F.adaptive_avg_pool1d(fused, 1).view(n_batch, -1),
            ),
            1,
        )

        hidden = F.leaky_relu(self.bn6(self.linear1(pooled)), negative_slope=0.2)
        hidden = self.dp1(hidden)
        hidden = F.leaky_relu(self.bn7(self.linear2(hidden)), negative_slope=0.2)
        hidden = self.dp2(hidden)
        return self.linear3(hidden)
Esempio n. 3
0
    def forward(self, x):
        """Encode a point cloud and decode coarse and fine reconstructions.

        The encoder is four edge-convolution stages plus conv5, reduced to a
        global feature by max pooling. folding1 maps that feature to a coarse
        point set. Each coarse point is replicated grid_size**2 times and
        combined with a grid (from self.build_grid) and the global feature.
        folding2 then predicts per-point offsets that are added to the
        replicated coarse points.

        Args:
            x: input tensor whose first dimension is the batch size.

        Returns:
            (coarse, fine): coarse is viewed as (-1, self.num_coarse, 3); fine
            is folding2's output plus a (-1, self.num_fine, 3) centre tensor.
        """
        batch_size = x.size()[0]
        x = get_graph_feature(x, k=self.args.k)
        x = self.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x1, k=self.args.k)
        x = self.conv2(x)
        x2 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x2, k=self.args.k)
        x = self.conv3(x)
        x3 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x3, k=self.args.k)
        x = self.conv4(x)
        x4 = x.max(dim=-1, keepdim=False)[0]

        x = torch.cat((x1, x2, x3, x4), dim=1)

        x = self.conv5(x)
        # Global feature: max-pool over the last axis.
        feature = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)

        coarse = self.folding1(feature)
        coarse = coarse.view(-1, self.num_coarse, 3)

        grid = self.build_grid(x.size()[0])
        grid_feat = grid.repeat(1, self.num_coarse, 1)

        # Replicate every coarse point grid_size**2 times -> (-1, num_fine, 3).
        point_feat = self.tile(self.expand_dims(coarse, 2),
                               [1, 1, self.grid_size**2, 1])
        point_feat = point_feat.view([-1, self.num_fine, 3])

        global_feat = self.tile(self.expand_dims(feature, 1),
                                [1, self.num_fine, 1])
        feat = torch.cat([grid_feat, point_feat, global_feat], dim=2)

        # The centre of each fine point is its replicated coarse point, which is
        # exactly point_feat (same tile/view of coarse). Reuse it rather than
        # recomputing; assumes self.tile/self.expand_dims are pure tensor ops.
        center = point_feat

        fine = self.folding2(feat.transpose(2, 1)).transpose(2, 1) + center

        return coarse, fine
Esempio n. 4
0
    def forward(self, x, l):
        """Per-point prediction conditioned on an auxiliary vector ``l``.

        The first three input channels are multiplied by a matrix predicted by
        self.transform_net; any extra channels pass through unchanged. Three
        edge-convolution stages follow. Their concatenation, reduced by conv6
        and a max over the last axis, is joined with conv7(l), repeated for
        every point, and fed with the local features through the conv8-conv11
        head.

        Args:
            x: 3-D input tensor, unpacked as (B, D, N).
            l: auxiliary tensor, reshaped to (B, -1, 1) before conv7.

        Returns:
            The head output with its last two axes swapped (channels last),
            made contiguous.
        """
        B, D, N = x.size()

        x0 = get_graph_feature(x, k=self.k)
        t = self.transform_net(x0)
        x = x.transpose(2, 1)  # (B, N, D)
        if D > 3:
            # Only the first three channels are transformed. Slicing works for
            # any D > 3; x.split(3, dim=2) would yield more than two chunks when
            # D > 6 and break the two-name unpacking.
            x, feature = x[:, :, :3], x[:, :, 3:]
        x = torch.bmm(x, t)
        if D > 3:
            x = torch.cat([x, feature], dim=2)
        x = x.transpose(2, 1)  # back to (B, D, N)

        x = get_graph_feature(x, k=self.k)
        x = self.conv1(x)
        x = self.conv2(x)
        x1 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x1, k=self.k)
        x = self.conv3(x)
        x = self.conv4(x)
        x2 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x2, k=self.k)
        x = self.conv5(x)
        x3 = x.max(dim=-1, keepdim=False)[0]

        x = torch.cat((x1, x2, x3), dim=1)

        x = self.conv6(x)
        x = x.max(dim=-1, keepdim=True)[0]

        # Condition the global feature on l before broadcasting it to all points.
        l = l.view(B, -1, 1)
        l = self.conv7(l)

        x = torch.cat((x, l), dim=1)
        x = x.repeat(1, 1, N)

        x = torch.cat((x, x1, x2, x3), dim=1)

        x = self.conv8(x)
        x = self.dp1(x)
        x = self.conv9(x)
        x = self.dp2(x)
        x = self.conv10(x)
        x = self.conv11(x)

        return x.permute(0, 2, 1).contiguous()
Esempio n. 5
0
	def forward(self, x):
		"""Compute raw per-point class scores for a batch of point clouds.

		Three edge-convolution stages produce local features. conv6 plus a
		max over the last axis yields a global feature, which is repeated for
		every point and concatenated with the local features before the
		conv7-conv9 head.

		Args:
			x: 3-D input tensor, unpacked as (batch, channels, num_points).

		Returns:
			Output of conv9, with no softmax applied.
		"""
		_, _, num_points = x.size()

		edge = get_graph_feature(x, self.k)
		feat1 = self.conv2(self.conv1(edge)).max(dim=-1, keepdim=False)[0]

		edge = get_graph_feature(feat1, k=self.k)
		feat2 = self.conv4(self.conv3(edge)).max(dim=-1, keepdim=False)[0]

		edge = get_graph_feature(feat2, k=self.k)
		feat3 = self.conv5(edge).max(dim=-1, keepdim=False)[0]

		local_feats = torch.cat((feat1, feat2, feat3), dim=1)

		# keepdim=True keeps a size-1 last axis so it can be repeated per point.
		global_feat = self.conv6(local_feats).max(dim=-1, keepdim=True)[0]
		global_feat = global_feat.repeat(1, 1, num_points)

		out = torch.cat((global_feat, feat1, feat2, feat3), dim=1)
		out = self.dp1(self.conv8(self.conv7(out)))
		# Softmax is deliberately not applied here; background links:
		#   https://towardsdatascience.com/cuda-error-device-side-assert-triggered-c6ae1c8fa4c3
		#   https://github.com/pytorch/pytorch/issues/1204
		return self.conv9(out)