def forward(self, pos, batch):
    """Classify point clouds with three conv stages driven by edge pseudo-coordinates.

    Every stage builds a radius graph over the current points, converts the
    edge offsets into pseudo-coordinates clamped to ``[0, 1]``, and applies a
    conv layer called as ``conv(x, edge_index, pseudo)`` followed by ELU. After
    the first two stages a farthest-point subset of the points is kept.

    Args:
        pos: point coordinates, one row per point.
        batch: cloud index of every point.

    Returns:
        ``F.log_softmax`` of the final linear layer over the last dimension.
    """

    def pseudo_coords(points, edges, r):
        # Rescale target-minus-source offsets so that |offset| <= r lands in [0, 1].
        shifted = (points[edges[1]] - points[edges[0]]) / (2 * r) + 0.5
        return shifted.clamp(min=0, max=1)

    # Start from a constant 1-channel feature per point.
    feats = pos.new_ones((pos.size(0), 1))
    stages = (
        (self.conv1, 0.2, 0.5),
        (self.conv2, 0.4, 0.25),
        (self.conv3, 1, None),  # last stage: no further sampling
    )
    for conv, r, keep_ratio in stages:
        edges = radius_graph(pos, r=r, batch=batch)
        feats = F.elu(conv(feats, edges, pseudo_coords(pos, edges, r)))
        if keep_ratio is not None:
            keep = fps(pos, batch, ratio=keep_ratio)
            feats, pos, batch = feats[keep], pos[keep], batch[keep]

    hidden = global_mean_pool(feats, batch)
    hidden = F.elu(self.lin1(hidden))
    hidden = F.elu(self.lin2(hidden))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(hidden), dim=-1)
def forward(self, pos, batch):
    """Classify point clouds with three radius-graph conv stages and max pooling.

    Each stage builds a radius graph, applies ``conv(x, pos, edge_index)`` with
    ReLU, and (after the first two stages) keeps a farthest-point subset.

    Args:
        pos: point coordinates, one row per point.
        batch: cloud index of every point.

    Returns:
        ``F.log_softmax`` of the final linear layer over the last dimension.
    """
    feats = None  # the first conv receives no input features
    stages = (
        (self.conv1, 0.2, 0.5),
        (self.conv2, 0.4, 0.25),
        (self.conv3, 1, None),
    )
    for conv, r, keep_ratio in stages:
        edges = radius_graph(pos, r=r, batch=batch)
        feats = F.relu(conv(feats, pos, edges))
        if keep_ratio is None:
            continue
        keep = fps(pos, batch, ratio=keep_ratio)
        feats, pos, batch = feats[keep], pos[keep], batch[keep]

    hidden = global_max_pool(feats, batch)
    hidden = F.relu(self.lin1(hidden))
    hidden = F.relu(self.lin2(hidden))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(hidden), dim=-1)
def forward(self, pos, batch):
    """Run three point-conv stages and an MLP head conditioned on a classifier's prediction.

    ``self.cls_model`` is evaluated without gradients. Its arg-max prediction is
    one-hot encoded (``self.NUM_CLASS`` columns) and concatenated to the pooled
    point feature before the linear layers.

    Args:
        pos: point coordinates, one row per point.
        batch: cloud index of every point.

    Returns:
        Raw output of ``self.lin3`` (no final activation).
    """
    with torch.no_grad():
        cls_out = self.cls_model(pos, batch)  # original note: n*3
        pred = cls_out.max(1)[1].reshape(-1, 1)
        # Allocate the one-hot directly on the prediction's device. Previously it was
        # built on the CPU and moved with a module-level `device` global.
        cls = torch.zeros(pred.size(0), self.NUM_CLASS, device=pred.device)
        cls.scatter_(1, pred, 1)

    x = F.relu(self.pcnn1(None, pos, batch))
    idx = fps(pos, batch, ratio=0.375)
    x, pos, batch = x[idx], pos[idx], batch[idx]
    x = F.relu(self.pcnn2(x, pos, batch))
    idx = fps(pos, batch, ratio=0.333)
    x, pos, batch = x[idx], pos[idx], batch[idx]
    x = F.relu(self.pcnn3(x, pos, batch))
    x = global_mean_pool(x, batch)

    x = torch.cat((x, cls), dim=1)
    x = F.relu(self.lin1(x))
    x = F.relu(self.lin2(x))
    x = F.dropout(x, p=0.5, training=self.training)
    return self.lin3(x)
def forward(self, data):
    """Classify point clouds with four hardtanh conv stages and a selectable global pool.

    Args:
        data: object exposing ``pos`` (only the first three columns are used as
            coordinates) and ``batch``.

    Returns:
        ``{'out': log_softmax(logits)}``.

    Raises:
        ValueError: if ``self.pool`` is neither ``'max'`` nor ``'mean'``. Previously an
            unknown value sent unpooled per-point features into the MLP head.
    """
    # data.x was read here before but immediately overwritten; the first conv gets None.
    pos, batch = data.pos[:, :3], data.batch

    x = F.hardtanh(self.conv1(None, pos, batch))
    idx = fps(pos, batch, ratio=0.375)
    x, pos, batch = x[idx], pos[idx], batch[idx]
    x = F.hardtanh(self.conv2(x, pos, batch))
    idx = fps(pos, batch, ratio=0.334)
    x, pos, batch = x[idx], pos[idx], batch[idx]
    x = F.hardtanh(self.conv3(x, pos, batch))
    x = F.hardtanh(self.conv4(x, pos, batch))

    if self.pool == 'max':
        x = global_max_pool(x, batch)
    elif self.pool == 'mean':
        x = global_mean_pool(x, batch)
    else:
        raise ValueError(
            f"unsupported pool {self.pool!r}; expected 'max' or 'mean'"
        )

    x = F.hardtanh(self.lin1(x))
    x = F.hardtanh(self.lin2(x))
    x = self.lin3(x)
    return {'out': F.log_softmax(x, dim=-1)}
def forward(self, pos, ctr, batch):
    """Predict three angle heads from a point cloud, its center, and a classifier prediction.

    ``self.cls_model`` is evaluated without gradients and its arg-max prediction
    is one-hot encoded. The pooled point feature, the one-hot, and an encoding
    of ``ctr`` are concatenated and fed through an MLP. The outputs of
    ``lin_psi``, ``lin_theta`` and ``lin_phi`` are then interleaved: the first 12
    columns of each head come first, followed by the remaining columns of each head.

    Args:
        pos: point coordinates, one row per point.
        ctr: center coordinates; reshaped to ``(num_clouds, 3)``.
        batch: cloud index of every point.

    Returns:
        The reordered concatenation of the three head outputs.
    """
    with torch.no_grad():
        cls_out = self.cls_model(pos, batch)  # original note: n*3
        pred = cls_out.max(1)[1].reshape(-1, 1)
        # Allocate on the prediction's device rather than relying on a module-level `device` global.
        cls = torch.zeros(pred.size(0), self.NUM_CLASS, device=pred.device)
        cls.scatter_(1, pred, 1)

    x = F.relu(self.pcnn1(None, pos, batch))
    idx = fps(pos, batch, ratio=0.375)
    x, pos, batch = x[idx], pos[idx], batch[idx]
    x = F.relu(self.pcnn2(x, pos, batch))
    idx = fps(pos, batch, ratio=0.333)
    x, pos, batch = x[idx], pos[idx], batch[idx]
    x = F.relu(self.pcnn3(x, pos, batch))
    x = global_mean_pool(x, batch)

    # One 3-D center per cloud; pass a plain int to view() instead of a tensor.
    num_clouds = int(batch.max()) + 1
    center_pt = ctr.view(num_clouds, 3)
    c = F.relu(self.linctr(center_pt))

    x = torch.cat((x, cls, c), dim=1)
    x = F.relu(self.lin1(x))
    x = F.relu(self.lin2(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = F.relu(self.lin3(x))

    x1 = self.lin_psi(x)
    x2 = self.lin_theta(x)
    x3 = self.lin_phi(x)
    return torch.cat(
        (x1[:, 0:12], x2[:, 0:12], x3[:, 0:12], x1[:, 12:], x2[:, 12:], x3[:, 12:]),
        1,
    )
def forward(self, data):
    """PointNet++-style classification with two bipartite set-abstraction stages.

    Each stage samples centroids with FPS, gathers neighbors within a radius
    (at most 64 each), and applies a set-abstraction conv on the
    ``(all points, centroids)`` pair. A global module then runs on features
    concatenated with positions. Its output is reshaped to groups of 128 rows and
    max-reduced per group. This assumes exactly 128 centroids per cloud (the
    original notes "128 points").

    Args:
        data: object exposing ``pos`` and ``batch``.

    Returns:
        ``F.log_softmax`` of the final linear layer over the last dimension.
    """
    pos, batch = data.pos, data.batch
    feats = None
    # (module, neighborhood radius, FPS ratio); original notes: ~512 then ~128 points.
    stages = ((self.local_sa1, 0.2, 0.5), (self.local_sa2, 0.4, 0.25))
    for sa, r, ratio in stages:
        centers = fps(pos, batch, ratio=ratio)
        row, col = radius(pos, pos[centers], r, batch, batch[centers], max_num_neighbors=64)
        edge_index = torch.stack([col, row], dim=0)  # transpose of radius() output
        feats = F.relu(sa(feats, (pos, pos[centers]), edge_index))
        pos, batch = pos[centers], batch[centers]

    feats = self.global_sa(torch.cat([feats, pos], dim=1))
    feats = feats.view(-1, 128, self.lin1.in_features).max(dim=1)[0]

    feats = F.dropout(F.relu(self.lin1(feats)), p=0.5, training=self.training)
    feats = F.dropout(F.relu(self.lin2(feats)), p=0.5, training=self.training)
    return F.log_softmax(self.lin3(feats), dim=-1)
def forward(self, pos, batch):
    """Classify point clouds with three ``pcnn`` stages and mean pooling.

    As a side effect the pooled global feature is stored on ``self.feature``.

    Args:
        pos: point coordinates, one row per point.
        batch: cloud index of every point.

    Returns:
        ``F.log_softmax`` of the final linear layer over the last dimension.
    """
    feats = F.relu(self.pcnn1(None, pos, batch))
    for layer, keep_ratio in ((self.pcnn2, 0.375), (self.pcnn3, 0.333)):
        keep = fps(pos, batch, ratio=keep_ratio)
        feats, pos, batch = feats[keep], pos[keep], batch[keep]
        feats = F.relu(layer(feats, pos, batch))

    pooled = global_mean_pool(feats, batch)
    # Kept on the module, presumably so callers can read the embedding after forward.
    self.feature = pooled

    hidden = F.relu(self.lin1(pooled))
    hidden = F.relu(self.lin2(hidden))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(hidden), dim=-1)
def forward(self, pos, batch):
    """Shared feature extractor with a classifier head and a discriminator head.

    ``self.features[0..2]`` are applied over radius graphs (r = 0.2, 0.4, 1),
    with FPS sampling after the first two. The max-pooled feature feeds both
    ``self.classifier`` (three layers) and ``self.discriminator`` (two layers).

    Args:
        pos: point coordinates, one row per point.
        batch: cloud index of every point.

    Returns:
        Tuple ``(class_log_probs, discriminator_log_probs)``, each from
        ``F.log_softmax`` over the last dimension.
    """
    schedule = ((0.2, 0.5), (0.4, 0.25), (1, None))
    feats = None
    for stage, (r, keep_ratio) in enumerate(schedule):
        edge_index = radius_graph(pos, r=r, batch=batch)
        feats = F.relu(self.features[stage](feats, pos, edge_index))
        if keep_ratio is not None:
            keep = fps(pos, batch, ratio=keep_ratio)
            feats, pos, batch = feats[keep], pos[keep], batch[keep]

    shared = global_max_pool(feats, batch)

    logits = F.relu(self.classifier[0](shared))
    logits = F.relu(self.classifier[1](logits))
    logits = F.dropout(logits, p=0.5, training=self.training)
    logits = self.classifier[2](logits)

    disc = F.relu(self.discriminator[0](shared))
    disc = F.dropout(disc, p=0.5, training=self.training)
    disc = self.discriminator[1](disc)

    return F.log_softmax(logits, dim=-1), F.log_softmax(disc, dim=-1)
def forward(self, data):
    """PointNet++-style classification using full-graph edges into sampled centroids.

    Each stage samples centroids with FPS and pairs every point with centroids
    within a radius. The centroid side is remapped to global point indices via
    ``centers[col]``. The conv runs on the full point set and its output is then
    restricted to the centroids. A global module on ``[features, pos]`` follows,
    reshaped to ``(-1, 128, 1024)`` and max-reduced. This assumes 128 centroids
    per cloud and 1024 global channels.

    Args:
        data: object exposing ``pos`` and ``batch``.

    Returns:
        ``F.log_softmax`` of the final linear layer over the last dimension.
    """
    pos, batch = data.pos, data.batch
    feats = None
    # (module, radius, FPS ratio); original notes: ~512 then ~128 points.
    for sa, r, ratio in ((self.local_sa1, 0.1, 0.5), (self.local_sa2, 0.2, 0.25)):
        centers = fps(pos, batch, ratio=ratio)
        row, col = radius(pos[centers], pos, r, batch[centers], batch, max_num_neighbors=64)
        edge_index = torch.stack([row, centers[col]], dim=0)
        feats = F.relu(sa(feats, pos, edge_index))
        feats, pos, batch = feats[centers], pos[centers], batch[centers]

    feats = self.global_sa(torch.cat([feats, pos], dim=1))
    feats = feats.view(-1, 128, 1024).max(dim=1)[0]

    feats = F.dropout(F.relu(self.lin1(feats)), p=0.5, training=self.training)
    feats = F.dropout(F.relu(self.lin2(feats)), p=0.5, training=self.training)
    return F.log_softmax(self.lin3(feats), dim=-1)
def sample(self, pos, batch, **kwargs):
    """Return farthest-point-sampled indices for a sparse (2-D) ``pos``.

    The sampling ratio is ``self._get_ratio_to_sample(number_of_rows)``.
    Extra keyword arguments are accepted and ignored.

    Raises:
        ValueError: if ``pos`` is not 2-D.
    """
    from torch_geometric.nn import fps

    is_two_dimensional = len(pos.shape) == 2
    if not is_two_dimensional:
        raise ValueError(
            " This class is for sparse data and expects the pos tensor to be of dimension 2"
        )
    ratio = self._get_ratio_to_sample(pos.shape[0])
    return fps(pos, batch, ratio=ratio)
def forward(self, points, batch):
    """Sample anchors with FPS and encode each anchor's radius neighborhood.

    The anchor count is about ``len(points) / self.nb_neighbors``. Neighbor
    offsets from their anchor are divided by ``self.radius`` and passed with the
    anchor assignment to ``self.neighborhood_encoder``.

    Returns:
        ``(anchor_points, features, anchor_batch)``.
    """
    anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
    anchors, anchor_batch = points[anchor_idx], batch[anchor_idx]

    cluster, neighbor_idx = gnn.radius(
        x=points, y=anchors, batch_x=batch, batch_y=anchor_batch, r=self.radius
    )
    offsets = (points[neighbor_idx] - anchors[cluster]) / self.radius

    return anchors, self.neighborhood_encoder(offsets, cluster), anchor_batch
def forward(self, points, batch):
    """Sample anchors with FPS and build a per-anchor feature from its neighborhood.

    Normalized neighbor offsets go through ``fc1..fc3`` (ReLU each). They are then
    max-pooled per anchor and passed through ``fc1_global..fc3_global`` (ReLU each).

    Returns:
        ``(anchor_points, anchor_features, anchor_batch)``.
    """
    anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
    anchors, anchor_batch = points[anchor_idx], batch[anchor_idx]

    cluster, neighbor_idx = gnn.radius(
        x=points, y=anchors, batch_x=batch, batch_y=anchor_batch, r=self.radius
    )
    hidden = (points[neighbor_idx] - anchors[cluster]) / self.radius

    for layer in (self.fc1, self.fc2, self.fc3):
        hidden = F.relu(layer(hidden))

    pooled = gnn.global_max_pool(x=hidden, batch=cluster)
    for layer in (self.fc1_global, self.fc2_global, self.fc3_global):
        pooled = F.relu(layer(pooled))

    return anchors, pooled, anchor_batch
def forward(self, x, pos, batch):
    """Set-abstraction step: FPS centroids, radius neighborhoods, bipartite conv.

    Returns:
        ``(x, pos, batch)`` restricted to the sampled centroids.
    """
    centers = fps(pos, batch, ratio=self.ratio)
    center_pos, center_batch = pos[centers], batch[centers]

    row, col = radius(pos, center_pos, self.r, batch, center_batch, max_num_neighbors=64)
    # Swap the two index rows returned by radius() (a transpose) for the conv.
    edge_index = torch.stack([col, row], dim=0)

    out = self.conv(x, (pos, center_pos), edge_index)
    return out, center_pos, center_batch
def forward(self, x, pos, batch, norm=None):
    """Downsample with farthest point sampling and aggregate features with ``self.conv``.

    Args:
        x: per-point input features for ``self.conv``.
        pos: point coordinates.
        batch: cloud index of every point.
        norm: per-point normals; only used by the ``'PPFConv'`` branch.

    Returns:
        ``((x, pos, batch), idx)``: features, positions and batch restricted to the
        sampled centroids, plus the centroid indices into the input points.

    Raises:
        ValueError: if ``self.conv_name`` is not ``'PointConv'``, ``'GraphConv'``
            or ``'PPFConv'``. Previously an unknown name returned unpooled ``x``.
    """
    # Pool points with FPS, returning about len(pos) * self.ratio centroids.
    idx = fps(pos, batch, ratio=self.ratio)
    # Points within `self.r` of each centroid, up to `self.K` per centroid.
    # `row` indexes the centroids (pos[idx]) and `col` indexes the input points.
    row, col = radius(pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=self.K)

    if self.conv_name == 'PointConv':
        # Bipartite call: targets live in pos[idx], so centroid-local `row` ids are correct.
        edge_index = torch.stack([col, row], dim=0)
        x = self.conv(x, (pos, pos[idx]), edge_index)
    elif self.conv_name in ('GraphConv', 'PPFConv'):
        # These convs run on the full point graph and are then restricted with [idx].
        # Target ids must therefore be input-point indices. Map the centroid-local
        # `row` through `idx`, otherwise messages land on nodes 0..len(idx)-1.
        edge_index = torch.stack([col, idx[row]], dim=0)
        if self.conv_name == 'GraphConv':
            x = self.conv(x, edge_index)[idx]
        else:
            x = self.conv(x, pos, norm, edge_index)[idx]
    else:
        raise ValueError(
            f"unsupported conv_name {self.conv_name!r}; "
            "expected 'PointConv', 'GraphConv' or 'PPFConv'"
        )

    pos, batch = pos[idx], batch[idx]
    return (x, pos, batch), idx
def subsample_fps(self, n_vert):
    """Store FPS sample indices of ``self.vert`` in ``self.samples``.

    FPS runs on a detached copy moved to ``device_cpu``, and the resulting
    indices are moved to ``device``. The requested count ``n_vert`` is passed as
    the ratio ``n_vert / len(self.vert)``.

    Raises:
        AssertionError: if ``n_vert`` exceeds the current vertex count.
    """
    total = self.vert.shape[0]
    assert n_vert <= total, "you can only subsample to less vertices than before"
    host_vert = self.vert.detach().to(device_cpu)
    self.samples = fps(host_vert, ratio=n_vert / total).to(device)
def downsample(self, data, with_features):
    """Optionally drop near-zero points, then FPS-downsample to at most ``self.num_points``.

    The coordinates under ``self.coordinates_key`` are indexed along their last
    axis. ``coords.permute(1, 0)`` requires them to be 2-D, presumably
    ``(dims, points)``; confirm against callers. ``'features'`` (when
    ``with_features``) and ``'time_stamps'`` (when present) are subsampled the
    same way along axis 1. The chosen indices are stored under ``'ds_inds'``.

    Args:
        data: mapping-like container supporting ``[]`` and ``keys()``.
        with_features: whether ``data['features']`` should be filtered too.

    Returns:
        The same ``data`` object, modified in place.
    """
    key = self.coordinates_key

    if self.remove_zeros:
        # Keep points whose norm over dim -2 exceeds a small threshold.
        mask = data[key].norm(dim=-2) > 0.0001
        data[key] = data[key][..., mask]
        if with_features:
            data['features'] = data['features'][:, mask]
        if "time_stamps" in data.keys():
            data['time_stamps'] = data['time_stamps'][:, mask]

    coords = data[key].to(self.device)
    b = convert_data_to_batch(coords.permute(1, 0).unsqueeze(dim=0))
    # Ask FPS for one extra point; the result is capped at num_points below.
    ratio = float(self.num_points + 1) / coords.shape[-1]
    inds = fps(b.pos, batch=b.batch.to(b.pos.device), ratio=ratio)
    if inds.shape[0] > self.num_points:
        inds = inds[:self.num_points]

    data[key] = data[key][:, inds.to(data[key].device)]
    if with_features:
        data['features'] = data['features'][:, inds.to(data['features'].device)]
    if "time_stamps" in data.keys():
        # Use the time-stamp tensor's own device. 'features' may be absent here
        # (with_features=False), so referencing its device would raise KeyError.
        data['time_stamps'] = data['time_stamps'][:, inds.to(data['time_stamps'].device)]
    data['ds_inds'] = inds
    return data
def fps_pooling(pos, x, edge_attr, batch=None, k=16, r=0.5, reduce='sum'):
    """Pool point features onto a farthest-point-sampled subset via k-NN aggregation.

    Each sampled point aggregates the features of its ``k`` nearest original
    points with ``reduce``.

    Args:
        pos: point positions.
        x: per-point features.
        edge_attr: tensor indexed with the sampled point indices, exactly like
            ``pos``. Despite its name it is treated per point here.
        batch: optional cloud index of every point; ``None`` means one cloud.
        k: neighbors aggregated into each sampled point.
        r: sampling ratio passed to ``fps``.
        reduce: one of ``'max'``, ``'mean'``, ``'add'``, ``'sum'``.

    Returns:
        ``(x, pos, edge_attr, batch)`` for the sampled points. ``batch`` stays
        ``None`` when ``None`` was given; previously that default crashed on ``batch[idx]``.
    """
    assert reduce in ['max', 'mean', 'add', 'sum']
    idx = fps(pos, batch, ratio=r)
    sampled_batch = None if batch is None else batch[idx]
    # i indexes the sampled points (pos[idx]); j indexes the original points.
    i, j = knn(pos, pos[idx], k, batch, sampled_batch)
    # dim_size pins exactly one output row per sampled point.
    x = scatter(x[j], i, dim=0, dim_size=idx.size(0), reduce=reduce)
    return x, pos[idx], edge_attr[idx], sampled_batch
def forward(self, points, features, batch):
    """Sample anchors with FPS and combine geometric and feature neighborhoods.

    Normalized neighbor offsets are encoded per anchor by ``self.neighborhood_enc``.
    For every neighbor, ``[offset, its anchor's encoding, neighbor features]`` goes
    through ``fc1``. The result is max-pooled per anchor, passed through
    ``fc1_global``, and concatenated with the anchor encoding.

    Returns:
        ``(anchor_points, anchor_features, anchor_batch)``.
    """
    anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
    anchors, anchor_batch = points[anchor_idx], batch[anchor_idx]

    cluster, neighbor_idx = gnn.radius(
        x=points, y=anchors, batch_x=batch, batch_y=anchor_batch, r=self.radius
    )
    neighbor_feats = features[neighbor_idx]
    offsets = (points[neighbor_idx] - anchors[cluster]) / self.radius

    # The encoding appears to be one row per anchor; index by cluster to broadcast it.
    anchor_enc = self.neighborhood_enc(offsets, cluster)
    per_neighbor = torch.cat([offsets, anchor_enc[cluster], neighbor_feats], dim=1)

    pooled = gnn.global_max_pool(x=F.relu(self.fc1(per_neighbor)), batch=cluster)
    global_feats = F.relu(self.fc1_global(pooled))

    return anchors, torch.cat([anchor_enc, global_feats], dim=1), anchor_batch
def forward(self, data):
    """Two-scale k-NN feature extractor followed by a mean-pooled MLP head.

    ``dsc3d`` and ``dd`` run on a k-NN graph of all points (sigmoid after each).
    FPS then keeps 20% of the points, and ``dd2`` runs on a k-NN graph of the
    kept points with their pooled features.

    Args:
        data: object exposing ``pos`` and ``batch``. The incoming ``edge_index``
            is not used; edges are rebuilt with ``knn_graph``.

    Returns:
        Output of ``self.sm`` applied to the head.
    """
    pos, batch = data.pos, data.batch

    # Build the first k-NN graph (no self-loops).
    edge_index = knn_graph(pos, self.k, batch, loop=False)

    # Extract features in 3D.
    _, _, features_3d = self.dsc3d(pos, edge_index)
    features_3d = torch.sigmoid(features_3d)
    _, _, features_dd = self.dd(pos, edge_index, features_3d)
    features_dd = torch.sigmoid(features_dd)

    # Pooling: keep 20% of the points (drop 80%) with FPS.
    index = fps(pos, batch=batch, ratio=0.2)
    pos = pos[index]
    features = features_dd[index]
    batch = batch[index]
    edge_index = knn_graph(pos, self.k, batch, loop=False)  # change pos to features for test later!

    # Extract features in 3D again. Pass the pooled features so their rows line up
    # with the pooled pos/edge_index (the unpooled features_dd was passed before).
    _, _, features_dd2 = self.dd2(pos, edge_index, features)
    features_dd2 = torch.sigmoid(features_dd2)

    # NOTE(review): assumes every batch holds exactly self.batch_size clouds with equal
    # point counts; a smaller final batch would fail this view - confirm.
    ys = features_dd2.view(self.batch_size, -1, self.out_size_2)
    ys = ys.mean(dim=1).view(-1, self.out_size_2)
    y1 = F.elu(self.nn1(ys))
    y2 = self.sm(self.nn2(y1))
    return y2
def forward(self, data):
    """Graph convs on normals+positions, one FPS downsampling, then max-pool classification.

    All edges get unit weight (``data.edge_attr`` is not used). After two convs
    over the full graph, points are sampled with FPS and aggregated onto the
    samples by ``self.con_int``. The adjacency is then filtered with
    ``self.filter_adj`` before a final conv.

    Returns:
        ``(log_probs, critical_points)``. Here ``global_max_pool`` is expected to
        return a pair, presumably a project-local variant; confirm.
    """
    # Node input: normals concatenated with positions.
    inputs = torch.cat([data.norm, data.pos], dim=1)
    x, batch = inputs, data.batch
    edge_index = data.edge_index
    edge_weight = torch.ones((edge_index.size(1),), dtype=x.dtype, device=edge_index.device)

    # Two passes over the full point set; the second also sees the raw input again.
    x = F.dropout(x, training=self.training, p=0.2)
    x = F.relu(self.conv1(x, edge_index, edge_weight))
    x = torch.cat([x, inputs], dim=1)
    x = F.dropout(x, training=self.training, p=0.2)
    x = F.relu(self.conv2(x, edge_index, edge_weight))

    # Downsample with FPS and aggregate neighbors within 0.4 onto the sampled points.
    idx = fps(data.pos, batch, ratio=0.5)
    row, col = radius(data.pos, data.pos[idx], 0.4, batch, batch[idx], max_num_neighbors=64)
    x = self.con_int(x, (data.pos, data.pos[idx]), torch.stack([col, row], dim=0))
    batch = batch[idx]
    edge_index, edge_weight = self.filter_adj(edge_index, edge_weight, idx, data.pos.size(0))

    x = torch.cat([x, inputs[idx]], dim=1)
    x = F.relu(self.conv3(x, edge_index, edge_weight))

    out, critical_points = global_max_pool(x, batch)
    out = F.log_softmax(self.lin1(out), dim=1)
    return out, critical_points
def get_voxels(self, cloud, context_cloud, vox_center):
    """Extract a target voxel and two context voxels around ``vox_center``, FPS-subsampled.

    - ``voxel_1``: points of ``cloud`` in the final-size voxel (``get_voxel`` with
      ``return_mask=True``), capped at ``self.n_samples``.
    - ``voxel_1_1``: points of ``cloud`` in the context-size voxel, capped at
      ``self.n_samples_context``.
    - ``voxel_0``: points of ``context_cloud`` in the context-size voxel, capped at
      ``self.n_samples_context``.

    An empty voxel is replaced by a single dummy point (a mean of another voxel)
    and a message is printed.

    Returns:
        ``(voxel_0, voxel_1, voxel_1_1)``.
    """

    def _fps_subsample(points, n_samples):
        # Deterministic FPS over a single cloud. The batch vector is created on the
        # points' device (it was always on the CPU before).
        single = torch.zeros(points.shape[0], dtype=torch.long, device=points.device)
        keep = fps(points, single, ratio=n_samples / points.shape[0], random_start=False)
        # fps may return slightly more than requested; cap at n_samples.
        return points[keep, :][:n_samples, :]

    voxel_mask_1 = get_voxel(cloud, vox_center, self.final_voxel_size, return_mask=True)
    voxel_1 = cloud[voxel_mask_1]
    voxel_0 = get_voxel(context_cloud, vox_center, self.context_voxel_size)

    if voxel_1.shape[0] == 0:
        voxel_1 = voxel_0.mean(dim=0).unsqueeze(0)
        print('Empty voxel,placing dummy point')
    else:
        voxel_1 = _fps_subsample(voxel_1, self.n_samples)

    voxel_1_1 = get_voxel(cloud, vox_center, self.context_voxel_size)
    if voxel_1_1.shape[0] == 0:
        # Averaging the empty tensor itself gives NaN. voxel_1 always has at
        # least one row here, so use it as the dummy source.
        voxel_1_1 = voxel_1.mean(dim=0).unsqueeze(0)
        print('Empty voxel,placing dummy point')
    else:
        voxel_1_1 = _fps_subsample(voxel_1_1, self.n_samples_context)

    if voxel_0.shape[0] == 0:
        voxel_0 = voxel_1.mean(dim=0).unsqueeze(0)
        print('Empty contenxt,placing dummy point')
    else:
        voxel_0 = _fps_subsample(voxel_0, self.n_samples_context)

    return voxel_0, voxel_1, voxel_1_1
def forward(self, x, pos, batch):
    """Thin the point set with farthest point sampling when ``self.scale_factor < 1``.

    ``self.scale_factor`` is passed to ``fps`` as the sampling ratio. Otherwise
    the inputs are returned unchanged. ``x`` may be ``None``.

    Returns:
        ``(x, pos, batch)``, possibly subsampled.
    """
    if not (self.scale_factor < 1):
        return x, pos, batch

    keep = fps(pos, batch, self.scale_factor)
    if x is not None:
        x = x[keep]
    return x, pos[keep], batch[keep]
def forward(self, x, pos, batch):
    """Set-abstraction step: FPS centroids, radius neighborhoods (<= 64), bipartite conv.

    Returns:
        ``(x, pos, batch)`` restricted to the sampled centroids.
    """
    centers = fps(pos, batch, ratio=self.ratio)
    center_pos, center_batch = pos[centers], batch[centers]

    # TODO: FIGURE OUT THIS WITH RESPECT TO NUMBER OF POINTS (max_num_neighbors=64)
    row, col = radius(pos, center_pos, self.r, batch, center_batch, max_num_neighbors=64)
    # Swap the two index rows returned by radius() (a transpose) for the conv.
    edge_index = torch.stack([col, row], dim=0)

    return self.conv(x, (pos, center_pos), edge_index), center_pos, center_batch
def forward(self, data):
    """Set-abstraction step on an ``(x, pos, batch)`` tuple.

    Samples centroids with ``self.ratio`` and gathers neighbors within
    ``self.radius`` (at most ``self.max_num_neighbors``). Then applies the
    bipartite ``self.conv``.

    Returns:
        A new ``(x, pos, batch)`` tuple restricted to the centroids.
    """
    x, pos, batch = data
    centers = fps(pos, batch, ratio=self.ratio)
    center_pos, center_batch = pos[centers], batch[centers]

    row, col = radius(
        pos, center_pos, self.radius, batch, center_batch,
        max_num_neighbors=self.max_num_neighbors,
    )
    edge_index = torch.stack([col, row], dim=0)
    pooled = self.conv(x, (pos, center_pos), edge_index)
    return (pooled, center_pos, center_batch)
def forward(self, pos, batch):
    """Classify point clouds with four ReLU conv stages (two FPS reductions) and mean pooling.

    Args:
        pos: point coordinates, one row per point.
        batch: cloud index of every point.

    Returns:
        ``F.log_softmax`` of the final linear layer over the last dimension.
    """
    feats = F.relu(self.conv1(None, pos, batch))
    for conv, keep_ratio in ((self.conv2, 0.375), (self.conv3, 0.334)):
        keep = fps(pos, batch, ratio=keep_ratio)
        feats, pos, batch = feats[keep], pos[keep], batch[keep]
        feats = F.relu(conv(feats, pos, batch))
    # The fourth stage runs on the same points as the third.
    feats = F.relu(self.conv4(feats, pos, batch))

    hidden = global_mean_pool(feats, batch)
    hidden = F.relu(self.lin1(hidden))
    hidden = F.relu(self.lin2(hidden))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(hidden), dim=-1)
def pointset_diameter(v, sample_times=100):
    """Estimate the diameter of a point cloud.

    Each of ``sample_times`` rounds runs farthest point sampling, which asks for
    just over two points, and measures the distance between the first two
    sampled points. The largest distance found is returned. This is an
    estimate, not the exact diameter.

    Args:
        v: 2-D tensor of points, one row per point.
        sample_times: number of FPS rounds.

    Returns:
        The largest distance found (a tensor). This is a zero 0-dim tensor when
        ``v`` has fewer than two points; that case raised IndexError before.
        ``-1.0`` when ``sample_times`` is 0.
    """
    n_pts, _ = v.shape
    if n_pts < 2:
        # The first-two-samples trick needs at least two points; the diameter is 0 here.
        return v.new_zeros(())
    eps = 1e-6
    diameter = -1.0
    for _ in range(sample_times):
        # Ratio just above 2 / n_pts so FPS returns at least two indices.
        index = tgnn.fps(v, ratio=2 / n_pts + eps)
        distance = (v[index[0]] - v[index[1]]).norm()
        diameter = max(distance, diameter)
    return diameter
def forward(self, data):
    """PointNet++-style classification using the older ``radius`` call form.

    Each stage samples centroids with FPS and builds ``edge_index`` directly
    from ``radius(centroids, all points, r, ...)`` with at most 48 neighbors
    (positional argument). It then applies the set-abstraction module on the
    full point set. A global module on ``[features, pos]`` follows, reshaped to
    ``(-1, 128, 1024)`` and max-reduced.

    Returns:
        ``F.log_softmax`` of the final linear layer over the last dimension.
    """
    pos, batch = data.pos, data.batch
    feats = None
    # (module, radius, FPS ratio); original notes: ~512 then ~128 points.
    for sa, r, ratio in ((self.local_sa1, 0.1, 0.5), (self.local_sa2, 0.2, 0.25)):
        centers = fps(pos, batch, ratio=ratio)
        edge_index = radius(pos[centers], pos, r, batch[centers], batch, 48)
        feats = F.relu(sa(feats, pos, edge_index))
        # NOTE(review): unlike pos/batch, feats is not indexed by `centers` here;
        # this only lines up if `sa` already emits one row per centroid - confirm.
        pos, batch = pos[centers], batch[centers]

    feats = self.global_sa(torch.cat([feats, pos], dim=1))
    feats = feats.view(-1, 128, 1024).max(dim=1)[0]

    feats = F.dropout(F.relu(self.lin1(feats)), p=0.5, training=self.training)
    feats = F.dropout(F.relu(self.lin2(feats)), p=0.5, training=self.training)
    return F.log_softmax(self.lin3(feats), dim=-1)
def forward(self, x, pos, norm, batch):
    """Set-abstraction step that also carries normals through the conv.

    Samples centroids with ``self.ratio`` and gathers neighbors within
    ``self.r`` (at most 32). Then applies ``self.conv`` on bipartite
    ``(pos, centroid pos)`` and ``(norm, centroid norm)`` pairs.

    Returns:
        ``(x, pos, norm, batch)`` restricted to the centroids.
    """
    centers = fps(pos, batch, ratio=self.ratio)
    # Either a radius graph or a nearest-neighbor graph could be built here.
    row, col = radius(pos, pos[centers], self.r, batch, batch[centers], max_num_neighbors=32)
    edge_index = torch.stack([col, row], dim=0)

    center_pos, center_norm = pos[centers], norm[centers]
    x = self.conv(x, (pos, center_pos), (norm, center_norm), edge_index)
    return x, center_pos, center_norm, batch[centers]
def test_simple(self):
    """Sanity-check torch_cluster FPS on a random CUDA cloud of 2048 points.

    Samples a quarter of the points and checks that the sum of returned indices
    is positive and below ``num_points * number_of_samples``.
    """
    num_points = 2048
    pos = torch.randn((num_points, 3)).cuda()
    batch = torch.zeros((num_points)).cuda().long()

    sampled = fps(pos, batch, 0.25).detach().cpu().numpy()
    index_sum = np.sum(sampled)
    looks_valid = 0 < index_sum < num_points * sampled.shape[0]

    assert looks_valid, "Your Pytorch Cluster FPS doesn't seem to return the correct value. It shouldn't be used to perform sampling"
def forward(self, pos, batch):
    """Classify point clouds with four leaky-ReLU conv stages and mean pooling.

    Every activation is ``leaky_relu`` with ``negative_slope=0.2``. FPS keeps
    37.5% and then 33.4% of the points after the first two stages.

    Returns:
        ``F.log_softmax`` of the final linear layer over the last dimension.
    """

    def act(t):
        return F.leaky_relu(t, negative_slope=0.2)

    feats = act(self.conv1(None, pos, batch))
    for conv, keep_ratio in ((self.conv2, 0.375), (self.conv3, 0.334)):
        keep = fps(pos, batch, ratio=keep_ratio)
        feats, pos, batch = feats[keep], pos[keep], batch[keep]
        feats = act(conv(feats, pos, batch))
    feats = act(self.conv4(feats, pos, batch))

    hidden = global_mean_pool(feats, batch)
    hidden = act(self.lin1(hidden))
    hidden = act(self.lin2(hidden))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(hidden), dim=-1)