Beispiel #1
0
    def forward(self, data):
        """Run the point-cloud classification network on a batch of inputs.

        data: a batch of input, torch.Tensor or torch_geometric.data.Data type
            - torch.Tensor: assumed (batch_size, C, num_points), channels
              first (C is 3 for xyz coordinates). It is converted internally
              into a Data object with flat ``pos`` and ``batch`` tensors.
            - torch_geometric.data.Data: must provide ``pos`` and ``batch``;
              ``x`` (per-point features) is optional and treated as None when
              it is absent.

        Returns:
            The output of ``self.lin3`` computed on the features produced by
            the three set-abstraction modules, after a two-layer MLP with
            ReLU activations and dropout (probability ``self.config["dropout"]``).
        """
        if isinstance(data, torch.Tensor):
            # Convert the dense tensor to the flat (pos, batch) layout used by
            # torch_geometric modules.
            points = data.transpose(1, 2).contiguous()
            batch_size, num_points, _ = points.shape  # (batch_size, num_points, C)
            pos = points.view(batch_size * num_points, -1)
            # Sample index of every point: [0]*num_points + [1]*num_points + ...
            # built in one op rather than a Python loop of row assignments.
            batch = torch.arange(
                batch_size, device=pos.device, dtype=torch.long
            ).repeat_interleave(num_points)

            data = Data()
            data.pos, data.batch = pos, batch

        # Read ``x`` without writing back to ``data`` so that a caller-supplied
        # object is not mutated as a side effect of the forward pass.
        data_in = getattr(data, "x", None), data.pos, data.batch
        sa1_out = self.sa1_module(data_in)
        sa2_out = self.sa2_module(sa1_out)
        sa3_out = self.sa3_module(sa2_out)
        # Only the features are needed by the classification head.
        x, _, _ = sa3_out

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.config["dropout"], training=self.training)
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=self.config["dropout"], training=self.training)
        x = self.lin3(x)
        return x
Beispiel #2
0
    def forward(self, data):
        """Run the point-cloud segmentation network on a batch of inputs.

        data: a batch of input, torch.Tensor or torch_geometric.data.Data type
            - torch.Tensor: (batch_size, 3, num_points), as common batch input

            - torch_geometric.data.Data, as torch_geometric batch input:
                data.x: (batch_size * ~num_points, C), batch nodes/points feature,
                    ~num_points means each sample can have a different number
                    of points/nodes. Optional; treated as None when absent.

                data.pos: (batch_size * ~num_points, 3)

                data.batch: (batch_size * ~num_points,), a column vector of
                    graph/point-cloud identifiers for all nodes of all
                    graphs/point clouds in the batch. See the PyTorch Geometric
                    documentation for more information.

        Returns:
            Per-point log-probabilities (``F.log_softmax`` over the last dim).
            - For tensor input: reshaped to (batch_size, num_points,
              self.num_classes). This assumes ``fp1_module`` returns one row
              per input point in input order.
            - For Data input: a tuple ``(x, batch)``, where ``batch`` is the
              batch vector returned by ``fp1_module``.
        """
        dense_input = isinstance(data, torch.Tensor)

        if dense_input:
            # Convert the dense tensor to the flat (pos, batch) layout used by
            # torch_geometric modules.
            points = data.transpose(1, 2).contiguous()
            batch_size, num_points, _ = points.shape  # (batch_size, num_points, 3)
            pos = points.view(batch_size * num_points, -1)
            # Sample index of every point: [0]*num_points + [1]*num_points + ...
            # built in one op rather than a Python loop of row assignments.
            batch = torch.arange(
                batch_size, device=pos.device, dtype=torch.long
            ).repeat_interleave(num_points)

            data = Data()
            data.pos, data.batch = pos, batch

        # Read ``x`` without writing back to ``data`` so that a caller-supplied
        # object is not mutated as a side effect of the forward pass.
        data_in = getattr(data, "x", None), data.pos, data.batch

        # Encoder: three set-abstraction stages.
        sa1_out = self.sa1_module(data_in)
        sa2_out = self.sa2_module(sa1_out)
        sa3_out = self.sa3_module(sa2_out)

        # Decoder: feature propagation back through the skip connections,
        # ending at the original input points.
        fp3_out = self.fp3_module(sa3_out, sa2_out)
        fp2_out = self.fp2_module(fp3_out, sa1_out)
        fp1_out = self.fp1_module(fp2_out, data_in)

        fp1_out_x, _, fp1_out_batch = fp1_out
        x = self.fc2(self.dropout1(self.fc1(fp1_out_x)))
        x = F.log_softmax(x, dim=-1)

        if dense_input:
            return x.view(batch_size, num_points, self.num_classes)
        return x, fp1_out_batch