def forward(self, data):
    """Run the set-abstraction stack and the MLP head on a batch of point clouds.

    Args:
        data: either
            - a dense ``torch.Tensor`` laid out as ``(batch_size, 3, num_points)``.
              It is flattened into graph-style ``(x, pos, batch)`` inputs
              with ``x=None`` and ``batch`` holding each point's sample index.
            - a graph-style batch object (e.g. ``torch_geometric.data.Data``)
              exposing ``pos``, ``batch`` and optionally ``x``. A missing
              ``x`` is treated as ``None``. The object is not modified.

    Returns:
        The output of ``self.lin3``. It is applied to the features produced by
        ``self.sa3_module``, after two ReLU + dropout hidden layers
        (``self.lin1`` and ``self.lin2``).
    """
    if isinstance(data, torch.Tensor):
        # (batch_size, C, num_points) -> (batch_size, num_points, C)
        points = data.transpose(1, 2)
        batch_size, num_points, _ = points.shape
        pos = points.reshape(batch_size * num_points, -1)
        # Sample index of every flattened point: [0]*N + [1]*N + ... + [B-1]*N.
        batch = torch.arange(
            batch_size, device=pos.device, dtype=torch.long
        ).repeat_interleave(num_points)
        data_in = (None, pos, batch)
    else:
        # Read x defensively instead of assigning ``data.x = None`` so the
        # caller's batch object is left untouched.
        data_in = (getattr(data, "x", None), data.pos, data.batch)

    sa1_out = self.sa1_module(data_in)
    sa2_out = self.sa2_module(sa1_out)
    x, _, _ = self.sa3_module(sa2_out)

    dropout_p = self.config["dropout"]
    x = F.relu(self.lin1(x))
    x = F.dropout(x, p=dropout_p, training=self.training)
    x = F.relu(self.lin2(x))
    x = F.dropout(x, p=dropout_p, training=self.training)
    return self.lin3(x)
def forward(self, data):
    """Predict per-point class log-probabilities for a batch of point clouds.

    Args:
        data: a batch of input, ``torch.Tensor`` or ``torch_geometric.data.Data``.

            - ``torch.Tensor``: ``(batch_size, 3, num_points)``, the common
              dense batch layout.
            - ``torch_geometric.data.Data``, the torch_geometric batch layout:

                - ``data.x``: ``(batch_size * ~num_points, C)``, per-point
                  features. ``~num_points`` means each sample can have a
                  different number of points/nodes. A missing ``x`` is
                  treated as ``None``.
                - ``data.pos``: ``(batch_size * ~num_points, 3)``.
                - ``data.batch``: ``(batch_size * ~num_points,)``, a vector of
                  graph/point-cloud identifiers for all nodes of all
                  graphs/point clouds in the batch. See the PyTorch Geometric
                  documentation for more information.

            The input object is not modified.

    Returns:
        - Dense tensor input: log-probabilities of shape
          ``(batch_size, num_points, self.num_classes)``.
        - Graph input: a tuple ``(log_probs, batch)``. ``log_probs`` has one
          row per point returned by ``self.fp1_module``. ``batch`` is the
          batch vector that ``self.fp1_module`` returns.
    """
    dense_input = isinstance(data, torch.Tensor)
    if dense_input:
        # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        points = data.transpose(1, 2)
        batch_size, num_points, _ = points.shape
        pos = points.reshape(batch_size * num_points, -1)
        # Sample index of every flattened point: [0]*N + [1]*N + ... + [B-1]*N.
        batch = torch.arange(
            batch_size, device=pos.device, dtype=torch.long
        ).repeat_interleave(num_points)
        data_in = (None, pos, batch)
    else:
        # Read x defensively instead of assigning ``data.x = None`` so the
        # caller's batch object is left untouched.
        data_in = (getattr(data, "x", None), data.pos, data.batch)

    sa1_out = self.sa1_module(data_in)
    sa2_out = self.sa2_module(sa1_out)
    sa3_out = self.sa3_module(sa2_out)

    # Each fp stage combines the previous stage's output with the matching
    # earlier sa-stage output (or the raw input for fp1).
    fp3_out = self.fp3_module(sa3_out, sa2_out)
    fp2_out = self.fp2_module(fp3_out, sa1_out)
    point_x, _, point_batch = self.fp1_module(fp2_out, data_in)

    x = self.fc2(self.dropout1(self.fc1(point_x)))
    x = F.log_softmax(x, dim=-1)

    if dense_input:
        return x.view(batch_size, num_points, self.num_classes)
    return x, point_batch