def forward(self, pos, batch):
    """Classify point clouds and score them with a discriminator head.

    Three hierarchical stages are applied. Each stage builds a radius graph
    over the current points, runs the matching ``self.features`` layer with
    ReLU and, except for the last stage, keeps a farthest-point-sampled
    subset of the points. The max-pooled per-graph feature feeds both the
    classifier head and the discriminator head.

    Args:
        pos: Point coordinates.
        batch: Graph-membership index for each point.

    Returns:
        Tuple ``(class_log_probs, discriminator_log_probs)``, both
        log-softmax over the last dimension.
    """
    # (graph radius, fps keep-ratio applied after the stage; None = no sampling)
    stages = ((0.2, 0.5), (0.4, 0.25), (1, None))
    h = None  # the first stage sees positions only
    for layer_idx, (r, keep_ratio) in enumerate(stages):
        graph = radius_graph(pos, r=r, batch=batch)
        h = F.relu(self.features[layer_idx](h, pos, graph))
        if keep_ratio is not None:
            keep = fps(pos, batch, ratio=keep_ratio)
            h, pos, batch = h[keep], pos[keep], batch[keep]

    pooled = global_max_pool(h, batch)

    # Classification head.
    logits = F.relu(self.classifier[0](pooled))
    logits = F.relu(self.classifier[1](logits))
    logits = F.dropout(logits, p=0.5, training=self.training)
    logits = self.classifier[2](logits)

    # Discriminator head, fed with the same pooled feature.
    disc = F.relu(self.discriminator[0](pooled))
    disc = F.dropout(disc, p=0.5, training=self.training)
    disc = self.discriminator[1](disc)

    return F.log_softmax(logits, dim=-1), F.log_softmax(disc, dim=-1)
def forward(self, pos, batch):
    """Classify point clouds with three radius-graph convolution stages.

    Each stage connects points within a stage-specific radius and applies
    its conv with ReLU. The first two stages then keep a farthest-point
    sample of the points. Features are max-pooled per graph and passed
    through a three-layer MLP head.

    Args:
        pos: Point coordinates.
        batch: Graph-membership index for each point.

    Returns:
        Log-softmax class scores, one row per pooled graph.
    """
    plan = (
        (self.conv1, 0.2, 0.5),
        (self.conv2, 0.4, 0.25),
        (self.conv3, 1, None),
    )
    feats = None  # conv1 receives positions only
    for conv, r, ratio in plan:
        edges = radius_graph(pos, r=r, batch=batch)
        feats = F.relu(conv(feats, pos, edges))
        if ratio is None:
            continue
        kept = fps(pos, batch, ratio=ratio)
        feats, pos, batch = feats[kept], pos[kept], batch[kept]

    out = global_max_pool(feats, batch)
    out = F.relu(self.lin1(out))
    out = F.relu(self.lin2(out))
    out = F.dropout(out, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(out), dim=-1)
def forward(self, pos, batch):
    """Classify point clouds with pseudo-coordinate graph convolutions.

    Node features start as ones. Each of three stages builds a radius graph
    and computes per-edge pseudo coordinates: the relative offset scaled
    into [0, 1]. It then applies its conv with ELU and, for the first two
    stages, keeps a farthest-point sample. Features are mean-pooled per
    graph and passed through an MLP head.

    Args:
        pos: Point coordinates.
        batch: Graph-membership index for each point.

    Returns:
        Log-softmax class scores, one row per pooled graph.
    """
    def edge_pseudo(points, edges, r):
        # Offset target-minus-source divided by the ball diameter, shifted
        # by 0.5 so a zero offset maps to 0.5, then clamped into [0, 1].
        offset = points[edges[1]] - points[edges[0]]
        return (offset / (2 * r) + 0.5).clamp(min=0, max=1)

    h = pos.new_ones((pos.size(0), 1))
    schedule = (
        (self.conv1, 0.2, 0.5),
        (self.conv2, 0.4, 0.25),
        (self.conv3, 1, None),
    )
    for conv, r, keep_ratio in schedule:
        edges = radius_graph(pos, r=r, batch=batch)
        h = F.elu(conv(h, edges, edge_pseudo(pos, edges, r)))
        if keep_ratio is not None:
            keep = fps(pos, batch, ratio=keep_ratio)
            h, pos, batch = h[keep], pos[keep], batch[keep]

    h = global_mean_pool(h, batch)
    h = F.elu(self.lin1(h))
    h = F.elu(self.lin2(h))
    h = F.dropout(h, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(h), dim=-1)
def get(self, idx): adj_all = torch.zeros((self.max_num_nodes, self.max_num_nodes), dtype=torch.float) # transform if self.dynamic_graph: data = torch.load(osp.join(self.processed_root, self.idxlist[idx])) num_nodes = data.num_nodes dist_path = os.path.join(self.root, 'proto', 'distance', self.setting.dataset, self.idxlist[idx]) choice, sample_num_node = self._sampling(num_nodes, self.sampling_ratio, dist_path) for key, item in data: if torch.is_tensor(item) and item.size(0) == num_nodes: data[key] = item[choice] # generate the graph if self.graph_sampler == 'knn': edge_index = radius_graph(data.pos, self.max_edge_distance, None, True, self.max_neighbours) adj = sparse_to_dense(edge_index) else: raise NotImplementedError adj[adj > 0] = 1 adj_all[:sample_num_node, :sample_num_node] = adj else: data = torch.load( osp.join(self.processed_fix_data_root, str(self.epoch), self.idxlist[idx])) if self.graph_sampler == 'knn': edge_index = radius_graph(data.pos, self.max_edge_distance, None, True, self.max_neighbours) adj = sparse_to_dense(edge_index) else: raise NotImplementedError num_nodes = adj.shape[0] adj_all[:num_nodes, :num_nodes] = adj feature = data.x if self.feature_type == 'c': feature = feature[:, -2:] elif self.feature_type == 'a': feature = feature[:, :-2] num_feature = feature.shape[1] feature = (feature - self.mean) / self.std feature_all = torch.zeros((self.max_num_nodes, num_feature), dtype=torch.float) if self.dynamic_graph: feature_all[:sample_num_node] = feature else: feature_all[:num_nodes] = feature label = data.y idx = torch.tensor(idx) return { 'adj': adj_all, 'feats': feature_all, 'label': label, 'num_nodes': sample_num_node if self.dynamic_graph else num_nodes, 'patch_idx': idx }
def forward(self, batch_data):
    """Run the edge/node/graph update stack and return graph features.

    A radius graph (``self.cutoff``) is built over the positions and
    converted into distance/angle/torsion features by ``xyztodat``. These
    are embedded and then refined by paired edge, node and graph update
    blocks.

    Args:
        batch_data: Object exposing ``z``, ``pos`` and ``batch``.

    Returns:
        The graph-level feature produced by the last graph update block.
        When ``self.energy_and_force`` is set, ``pos`` is switched to
        require grad so callers can differentiate the result.
    """
    z = batch_data.z
    pos = batch_data.pos
    batch = batch_data.batch
    if self.energy_and_force:
        pos.requires_grad_()

    edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
    n_atoms = z.size(0)
    dist, angle, torsion, i, j, idx_kj, idx_ji = xyztodat(
        pos, edge_index, n_atoms, use_torsion=True)
    geometry = self.emb(dist, angle, torsion, idx_kj)

    # Initial edge, node and graph features; the graph feature starts from
    # zeros shaped like the per-graph sum of node features.
    edge_feat = self.init_e(z, geometry, i, j)
    node_feat = self.init_v(edge_feat, i, num_nodes=pos.size(0))
    graph_feat = self.init_u(
        torch.zeros_like(scatter(node_feat, batch, dim=0)), node_feat, batch)

    blocks = zip(self.update_es, self.update_vs, self.update_us)
    for upd_edge, upd_node, upd_graph in blocks:
        edge_feat = upd_edge(edge_feat, geometry, idx_kj, idx_ji)
        node_feat = upd_node(edge_feat, i)
        graph_feat = upd_graph(graph_feat, node_feat, batch)
    return graph_feat
def forward(self, pos, p, f=None, batch=None):
    """Refine per-point label probabilities with mean-field CRF iterations.

    Pairwise weights come from a radius graph (at most ``self.kernel_size``
    neighbours). For each kernel, features are projected with ``self.F``.
    The squared feature difference across an edge is turned into
    ``exp(-||d||^2)``, and the kernels are mixed by ``self.W``.

    Args:
        pos: Point positions used to build the graph.
        p: Initial probabilities, also used as the unary term ``-log(p)``.
        f: Per-point features. Despite the ``None`` default it is required:
            the projection step calls ``f.unsqueeze``.
        batch: Optional graph-membership index for each point.

    Returns:
        Refined probabilities after ``self.steps`` mean-field updates.
    """
    num_points = pos.shape[0]
    # Messages are gathered from edge_index[0] and scattered onto edge_index[1].
    src, dst = radius_graph(pos, r=self.radius, batch=batch,
                            max_num_neighbors=self.kernel_size)
    unary = -torch.log(p)

    stacked = f.unsqueeze(0).repeat(self.num_kernels, 1, 1)
    proj = torch.bmm(stacked, self.F).permute((1, 0, 2))  # [N, K, H]
    diff = proj[src] - proj[dst]  # [E, K, H]
    edge_w = torch.mm(torch.exp(-torch.sum(diff ** 2, dim=-1)), self.W)  # [E, 1]

    q = p
    for _ in range(self.steps):
        msg = scatter_add(q[src] * edge_w, dst, dim=0, dim_size=num_points)
        compat = torch.mm(msg, self.C)  # compatibility transform
        q = torch.softmax(-unary - compat, dim=-1)
    return q
def forward(self, data, phase):
    """Graph-level prediction from three conv + graph-norm layers.

    The incoming ``data.edge_index`` is ignored. Connectivity is rebuilt as
    a radius-4 graph over ``data.pos`` (cast to float). Dropout (p=0.8)
    follows the first two normalised layers and is active only when
    ``phase == 'train'``.

    Returns:
        ``{'A2': prediction}`` with one row per pooled graph.
    """
    is_train = phase == 'train'
    batch = data.batch
    graph = radius_graph(data.pos.float(), r=4, batch=batch)

    h = self.lin01(data.x)
    for conv, norm in ((self.conv1, self.graph_norm_1),
                       (self.conv2, self.graph_norm_2)):
        h = norm(F.relu(conv(h, graph)), batch)
        h = F.dropout(h, p=0.8, training=is_train)
    h = self.graph_norm_3(F.relu(self.conv3(h, graph)), batch)

    pooled = self.pool(h, batch)
    return {'A2': self.lin3(self.lin2(pooled))}
def forward(self, data):
    """Predict energy (and optionally forces) for a batch of structures.

    Node features are rows of ``self.embedding`` indexed by
    ``atomic_numbers - 1``. Edges come from a radius graph with cutoff
    ``self.cutoff``. Edge features are ``self.distance_expansion`` of the
    edge lengths.

    Returns:
        ``energy``, or ``(energy, forces)`` when ``self.regress_forces`` is
        set. Forces are the negative gradient of energy w.r.t. positions.
    """
    species = data.atomic_numbers
    if self.embedding.device != species.device:
        self.embedding = self.embedding.to(species.device)
    data.x = self.embedding[species.long() - 1]

    data.edge_index = radius_graph(data.pos, r=self.cutoff, batch=data.batch)
    src, dst = data.edge_index
    pos = data.pos.requires_grad_(True) if self.regress_forces else data.pos
    data.edge_weight = (pos[src] - pos[dst]).norm(dim=-1)
    data.edge_attr = self.distance_expansion(data.edge_weight)

    feats = self.conv_to_fc(self._convolve(data))
    if hasattr(self, "fcs"):
        feats = self.fcs(feats)
    energy = self.fc_out(feats)

    if not self.regress_forces:
        return energy
    grad = torch.autograd.grad(
        energy,
        pos,
        grad_outputs=torch.ones_like(energy),
        create_graph=True,
    )[0]
    return energy, -1 * grad
def __call__(self, sample):
    """Attach a Euclidean radius graph for each edge-index level of ``sample``.

    Level ``k`` pairs the k-th sorted ``edge_index`` attribute with the k-th
    sorted attribute whose name contains ``'pos'``, using
    ``self._radius[k]``. In visualisation mode only the plain
    ``edge_index`` attribute is processed; otherwise every ``edge_index``
    attribute that is not a dilated one is processed.

    With ``self._override`` the new edges replace the original attribute.
    Otherwise they are stored under the name with ``edge_index`` replaced
    by ``euclidean_edge_index``. Unless ``self._keep_pos`` is set,
    attributes containing ``'pos_'`` are deleted afterwards.
    """
    attr_names = dir(sample)
    if self._visualization:
        edge_keys = sorted(name for name in attr_names if name == 'edge_index')
    else:
        edge_keys = sorted(
            name for name in attr_names
            if 'edge_index' in name and 'dilated' not in name)
    pos_keys = sorted(name for name in attr_names if 'pos' in name)

    for level, edge_key in enumerate(edge_keys):
        edges = radius_graph(sample[pos_keys[level]], self._radius[level],
                             max_num_neighbors=self._max_neigh)
        if self._override:
            # Single-conv architectures only need one edge set.
            target = edge_key
        else:
            target = edge_key.replace('edge_index', 'euclidean_edge_index')
        sample[target] = edges

    if not self._keep_pos:
        # Positions of later levels are no longer needed once edges exist,
        # so we do not need to know them anymore.
        for name in sorted(n for n in dir(sample) if 'pos_' in n):
            delattr(sample, name)
    return sample
def get(self, idx):
    """Load sample ``idx`` with an unpadded dense adjacency.

    The adjacency and feature matrices keep the sample's own node count
    (no padding to a common size), so only a batch size of 1 is supported.

    Returns:
        dict with ``'adj'`` (binarised dense adjacency), ``'feats'``
        (selected, standardised node features), ``'label'`` (``data.y``),
        ``'num_nodes'`` (rows of the adjacency) and ``'patch_idx'``
        (``tensor([idx])``).

    Raises:
        NotImplementedError: if ``self.graph_sampler`` is not ``'knn'``.
    """
    data = torch.load(osp.join(self.processed_root, self.idxlist[idx]))

    # 'knn' is realised as a radius graph with self-loops and at most
    # self.max_neighbours neighbours per node.
    if self.graph_sampler == 'knn':
        edge_index = radius_graph(data.pos, self.max_edge_distance, None, True,
                                  self.max_neighbours)
        adj_all = sparse_to_dense(edge_index)
    else:
        raise NotImplementedError
    adj_all[adj_all > 0] = 1  # binarise the dense adjacency
    num_nodes = adj_all.shape[0]

    # 'c' keeps only the last two feature columns, 'a' drops them.
    feature = data.x
    if self.feature_type == 'c':
        feature = feature[:, -2:]
    elif self.feature_type == 'a':
        feature = feature[:, :-2]
    feature = (feature - self.mean) / self.std

    return {
        'adj': adj_all,
        'feats': feature,
        'label': data.y,
        'num_nodes': num_nodes,
        'patch_idx': torch.tensor([idx]),
    }
def forward(self, atomic_ns, coords, batch_node_vec):
    """Partial (work-in-progress) forward pass; currently returns ``None``.

    Planned pipeline:
      * interatomic distances -> radial basis (rbf)
      * distances + triplet angles -> spherical basis (sbf)
      * rbf -> embedding; rbf + sbf -> interaction blocks
      * sum(embedding, interactions) -> output MLP

    Only the first stages exist so far. The method builds a radius graph
    (``self.cutoff_val``), derives distances/angles with ``xyz_to_dg`` and
    embeds them with ``self.emb``. The results are not used yet.
    """
    edge_index = radius_graph(coords, self.cutoff_val, batch_node_vec)
    dists, angles, node_is, node_js, kj, ji = xyz_to_dg(
        coords, edge_index, atomic_ns.size(0))
    dist_embs, angle_embs = self.emb(dists, angles, kj)
    # TODO: initialise edge/node/graph features, add interaction and output
    # blocks, and return the model output.
    return None
def forward(self, z, pos, batch=None):
    """Directional message-passing prediction from atom types and positions.

    Args:
        z: Per-node atom types, consumed by ``self.emb``.
        pos: Node coordinates (3 components, required by the cross product).
        batch: Optional graph-membership index per node. When ``None`` all
            nodes are summed into a single output.

    Returns:
        Sum of the output-block contributions over nodes, one row per graph
        when ``batch`` is given.
    """
    edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                              max_num_neighbors=self.max_num_neighbors)

    i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
        edge_index, num_nodes=z.size(0))

    # Calculate distances.
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    # Calculate angles via atan2(|a x b|, a . b).
    pos_i = pos[idx_i]
    pos_ji, pos_ki = pos[idx_j] - pos_i, pos[idx_k] - pos_i
    a = (pos_ji * pos_ki).sum(dim=-1)
    # dim must be explicit: without it torch.cross uses the FIRST dimension
    # of size 3, i.e. the triplet dimension when there are exactly 3 triplets.
    b = torch.cross(pos_ji, pos_ki, dim=1).norm(dim=-1)
    angle = torch.atan2(b, a)

    rbf = self.rbf(dist)
    sbf = self.sbf(dist, angle, idx_kj)

    # Embedding block.
    num_nodes = pos.size(0)
    x = self.emb(z, rbf, i, j)
    P = self.output_blocks[0](x, rbf, i, num_nodes=num_nodes)

    # Interaction blocks. Every output block gets num_nodes so all scatters
    # return the same number of rows; otherwise trailing nodes without
    # incoming edges shrink the result and the sum fails.
    for interaction_block, output_block in zip(self.interaction_blocks,
                                               self.output_blocks[1:]):
        x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
        P = P + output_block(x, rbf, i, num_nodes=num_nodes)

    return P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)
def gen_graph(
    node_feature_path: str,
    node_feature_cols: List[str],
    label: int,
    max_neighbours: int = 8,
    radius: float = 50,
) -> Data:
    """
    Generates graph from node features as pytorch object

    Each selected feature column is mean-centred and divided by its range
    (max - min). Constant or all-NaN columns become 0. Nodes are connected
    by a radius graph over their centroid coordinates, self-loops included.

    Arguments:
        node_feature_path {str} -- Path to node features for tile
        node_feature_cols {List[str]} -- list of features to use
        label {int} -- 1, or 0 for cancer/no-cancer

    Keyword Arguments:
        max_neighbours {int} -- max neighbors per node for graph edges (default: {8})
        radius {float} -- max pixel radius for graph edge (default: {50})

    Returns:
        Data object for use with pytorch geometric graph convolution
    """
    features = pd.read_csv(
        node_feature_path,
        converters={
            "diagnostics_Mask-original_CenterOfMass": ast.literal_eval
        }
    )

    # Normalize features. A zero range is mapped to NaN so that constant
    # columns produce NaN (filled with 0 below) instead of +/-inf, which
    # float rounding of the column mean would otherwise produce.
    f_df = features[node_feature_cols]
    value_range = (f_df.max() - f_df.min()).replace(0, float("nan"))
    f_norm_df = (f_df - f_df.mean()) / value_range
    f_norm_nafilled = f_norm_df.fillna(0)

    # Set centroid coordinates from node features
    coordinates = torch.tensor(
        features['diagnostics_Mask-original_CenterOfMass'].tolist()
    )

    # Set node features from normalized features
    node_features = torch.tensor(
        f_norm_nafilled.astype(float).values,
        dtype=torch.float32
    )

    # Initialize data for pytorch
    y = torch.tensor([label], dtype=torch.long)
    data = Data(
        x=node_features,
        pos=coordinates,
        y=y,
        num_classes=2
    )

    # Create graph: positional args are (pos, r, batch, loop, max_num_neighbors)
    data.edge_index = radius_graph(
        data.pos,
        radius,
        None,
        True,
        max_neighbours
    )
    return data
def pt_to_gexf(numpyfile, savepath, sample=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1)):
    """Export randomly subsampled radius graphs of a point set as GEXF files.

    Coordinates loaded from ``numpyfile`` are divided by 1500. For every
    ratio in ``sample``, ``int(num_nodes * ratio)`` points are drawn at
    random. They are connected with a radius graph (radius 100, self-loops,
    at most 8 neighbours) and written to
    ``<savepath>/<ratio><basename>.gexf``. The first two coordinates are
    stored as the node attributes ``'x'`` and ``'y'``.

    Args:
        numpyfile: Path to a ``.npy`` array of point coordinates.
        savepath: Output directory.
        sample: Iterable of sampling ratios (a tuple default avoids the
            shared-mutable-default pitfall; any iterable works).
    """
    coordinates = np.load(numpyfile)
    coordinates = torch.from_numpy(coordinates / 1500.).to(torch.float)
    num_nodes = coordinates.shape[0]
    base_name = os.path.basename(numpyfile).replace('.npy', '.gexf')

    for sampling_ratio in sample:
        num_sample = int(num_nodes * sampling_ratio)
        # NOTE(review): replace=True draws with replacement, so a sample can
        # contain duplicate points; confirm this is intended.
        choice = np.random.choice(num_nodes, num_sample, replace=True)
        sampled = coordinates[choice]

        # NOTE(review): radius 100 on coordinates scaled by 1/1500 likely
        # leaves only the 8-neighbour cap effective — confirm.
        edge_index = radius_graph(sampled, 100, None, True, 8)
        adj = sparse_to_dense(edge_index).numpy()
        xy = sampled.numpy()

        # from_numpy_matrix was removed in NetworkX 3.0.
        G = nx.from_numpy_array(adj)
        nx.set_node_attributes(G, dict(enumerate(xy[:, 0].tolist())), 'x')
        nx.set_node_attributes(G, dict(enumerate(xy[:, 1].tolist())), 'y')
        nx.write_gexf(G, os.path.join(savepath, str(sampling_ratio) + base_name))
def get(self, idx):
    """Load sample ``idx`` as a graph data object with a fresh radius graph.

    Dynamic mode loads from ``self.processed_root`` and subsamples nodes via
    ``self._sampling``. Fixed mode loads the pre-generated sample of the
    current epoch (``self.epoch`` for the train split, ``self.val_epoch``
    otherwise). Feature columns are selected by ``self.feature_type`` and
    standardised with ``self.mean`` / ``self.std``.

    Raises:
        NotImplementedError: if ``self.graph_sampler`` is not ``'knn'``.
    """
    epoch = self.epoch if self.split == 'train' else self.val_epoch
    name = self.idxlist[idx]
    if self.dynamic_graph:
        path = osp.join(self.processed_root, name)
    else:
        path = osp.join(self.processed_fix_data_root, str(epoch), name)
    data = torch.load(path)

    # 'c' keeps only the last two feature columns, 'a' drops them.
    if self.feature_type == 'c':
        data.x = data.x[:, -2:]
    elif self.feature_type == 'a':
        data.x = data.x[:, :-2]

    if self.dynamic_graph:
        n = data.num_nodes
        dist_path = os.path.join(self.root, 'proto', 'distance',
                                 self.setting.dataset, name)
        choice, _ = self._sampling(n, self.sampling_ratio, dist_path)
        # Subsample every tensor whose leading dimension is the node count.
        for key, item in data:
            if torch.is_tensor(item) and item.size(0) == n:
                data[key] = item[choice]

    if self.graph_sampler != 'knn':
        raise NotImplementedError
    # 'knn' = radius graph with self-loops, capped at max_neighbours.
    data.edge_index = radius_graph(data.pos, self.max_edge_distance, None,
                                   True, self.max_neighbours)

    data.patch_idx = torch.tensor([idx])
    data.x = (data.x - self.mean) / self.std
    return data
def __call__(self, data):
    """Set ``data.edge_index`` to a symmetric radius graph over ``data.pos``.

    Neighbours lie within ``self.radius`` and are capped at 64 per node. No
    batch vector is passed, so all points are treated as one graph. Missing
    reverse edges are added by ``to_undirected``.
    """
    pos = data.pos
    directed = radius_graph(pos, self.radius, None, max_num_neighbors=64)
    data.edge_index = to_undirected(directed, num_nodes=pos.size(0))
    return data
def forward(self, data):
    """Compute model outputs for a batch of atomic structures.

    Edges come from ``radius_graph_pbc`` (when ``self.otf_graph``) followed
    by ``get_pbc_distances`` (when ``self.use_pbc``), or from a plain
    ``radius_graph`` otherwise. They are filtered by ``self._filter_edges``
    and passed to ``self._forward_helper``, whose result is returned.

    Side effects: sets ``self.device``, ``self.num_atoms`` and
    ``self.batch_size``. With ``otf_graph`` it also writes ``edge_index``,
    ``cell_offsets`` and ``neighbors`` onto ``data``. With
    ``regress_forces``, ``data.pos`` is set to require grad. With
    ``show_timing_info``, it prints CUDA memory statistics (requires CUDA).
    """
    self.device = data.pos.device
    self.num_atoms = len(data.batch)
    self.batch_size = len(data.natoms)

    atomic_numbers = data.atomic_numbers.long()
    pos = data.pos.requires_grad_(True) if self.regress_forces else data.pos

    if self.otf_graph:
        data.edge_index, data.cell_offsets, data.neighbors = radius_graph_pbc(
            data, self.cutoff, 100)

    if not self.use_pbc:
        edge_index = radius_graph(pos, r=self.cutoff, batch=data.batch)
        src, dst = edge_index
        edge_distance_vec = pos[src] - pos[dst]
        edge_distance = edge_distance_vec.norm(dim=-1)
    else:
        assert (atomic_numbers.dim() == 1
                and atomic_numbers.dtype == torch.long)
        pbc = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
            return_distance_vec=True,
        )
        edge_index = pbc["edge_index"]
        edge_distance = pbc["distances"]
        edge_distance_vec = pbc["distance_vec"]

    edge_index, edge_distance, edge_distance_vec = self._filter_edges(
        edge_index, edge_distance, edge_distance_vec, self.max_num_neighbors)
    outputs = self._forward_helper(data, edge_index, edge_distance,
                                   edge_distance_vec)

    if self.show_timing_info is True:
        torch.cuda.synchronize()
        n_edges = len(edge_index[0])
        print("Memory: {}\t{}\t{}".format(
            n_edges,
            torch.cuda.memory_allocated() / (1000 * n_edges),
            torch.cuda.max_memory_allocated() / 1000000,
        ))
    return outputs
def __call__(self, data):
    """Replace ``data``'s connectivity with a radius graph over ``data.pos``.

    ``edge_attr`` is cleared, since stale edge features would not match the
    new edges. A ``batch`` vector is used when ``data`` has one. Radius,
    self-loop flag and neighbour cap come from ``self.r``, ``self.loop`` and
    ``self.max_num_neighbors``.
    """
    data.edge_attr = None
    batch = None
    if 'batch' in data:
        batch = data.batch
    data.edge_index = radius_graph(
        data.pos,
        r=self.r,
        batch=batch,
        loop=self.loop,
        max_num_neighbors=self.max_num_neighbors,
    )
    return data
def generate_random_graph(n, radius, max_num_neighbors, batch_size, device):
    """Sample ``n`` random 2-D points and connect them with a radius graph.

    Points are ``0.90 * (U[0, 1) + 0.05)``, i.e. they lie in
    [0.045, 0.945) per axis, kept away from the unit square's border.
    Self-loops are included. ``batch_size`` is accepted but unused: a
    single un-batched graph is produced.

    Returns:
        ``(points, edge_index, num_nodes, batch)``. Here ``num_nodes`` is
        ``tensor([n])`` (long), ``edge_index`` is in COO format and
        ``batch`` is ``None``.
    """
    num_nodes = th.tensor([n], device=device, dtype=th.long)
    points = th.rand(n, 2, device=device).add(0.05).mul(0.90)
    edges = radius_graph(points, radius, batch=None, loop=True,
                         max_num_neighbors=max_num_neighbors)
    return points, edges, num_nodes, None
def forward(self, data):
    """Predict energy (and optionally forces) for a batch of structures.

    Node features are rows of ``self.embedding`` indexed by
    ``atomic_numbers - 1``. Edges are rebuilt with ``radius_graph_pbc`` when
    ``self.otf_graph`` is set. Distances come from ``get_pbc_distances``
    with ``self.use_pbc``, or from a plain radius graph otherwise. They are
    expanded by ``self.distance_expansion`` into edge features.

    Returns:
        ``energy``, or ``(energy, forces)`` when ``self.regress_forces`` is
        set. Forces are the negative gradient of energy w.r.t. positions.
    """
    numbers = data.atomic_numbers
    if self.embedding.device != numbers.device:
        self.embedding = self.embedding.to(numbers.device)
    data.x = self.embedding[numbers.long() - 1]

    pos = data.pos
    if self.regress_forces:
        pos = pos.requires_grad_(True)

    if self.otf_graph:
        data.edge_index, data.cell_offsets, data.neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device)

    if self.use_pbc:
        pbc = get_pbc_distances(pos, data.edge_index, data.cell,
                                data.cell_offsets, data.neighbors)
        data.edge_index = pbc["edge_index"]
        distances = pbc["distances"]
    else:
        data.edge_index = radius_graph(data.pos, r=self.cutoff, batch=data.batch)
        src, dst = data.edge_index
        distances = (pos[src] - pos[dst]).norm(dim=-1)
    data.edge_attr = self.distance_expansion(distances)

    feats = self.conv_to_fc(self._convolve(data))
    if hasattr(self, "fcs"):
        feats = self.fcs(feats)
    energy = self.fc_out(feats)

    if not self.regress_forces:
        return energy
    grad = torch.autograd.grad(
        energy,
        pos,
        grad_outputs=torch.ones_like(energy),
        create_graph=True,
    )[0]
    return energy, -1 * grad
def forward(self, data):
    """Predict skinning logits for a mesh.

    When ``self.configs['euc_radius'] > 0``, a Euclidean radius graph over
    ``mesh.pos`` is attached as ``mesh.euc_edge_index``. Otherwise it is
    set to ``None``. The mesh then goes through ``self.edgeconv_feat`` and
    ``self.skinnet``.

    Args:
        data: The mesh graph; modified in place.

    Returns:
        Output of ``self.skinnet`` on the per-vertex mesh features.
    """
    mesh = data
    radius = self.configs['euc_radius']
    if radius > 0:
        # Pass the batch vector (when present) so that, for several meshes
        # batched together, edges never connect vertices of different meshes.
        mesh.euc_edge_index = radius_graph(
            mesh.pos, radius, batch=getattr(mesh, 'batch', None))
    else:
        mesh.euc_edge_index = None
    mesh_feat, _global_feat = self.edgeconv_feat(mesh)
    return self.skinnet(mesh_feat)
def forward(self, z, pos, batch=None):
    """Predict one scalar per graph from atom types and positions.

    Args:
        z: 1-D ``torch.long`` tensor of atom types (asserted below).
        pos: Node positions; presumably (num_nodes, 3) — TODO confirm.
        batch: Optional graph index per node; defaults to all zeros, i.e. a
            single graph.

    Returns:
        Per-graph aggregation (``reduce=self.readout``) of the per-node
        scalars, shape (num_graphs, 1), optionally multiplied by
        ``self.scale``.
    """
    assert z.dim() == 1 and z.dtype == torch.long
    batch = torch.zeros_like(z) if batch is None else batch

    h = self.embedding(z)

    edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=1000)
    row, col = edge_index
    edge_vec = pos[row] - pos[col]
    # Spherical harmonics of edge vectors, scaled down by sqrt(num_neighbors).
    edge_sh = o3.spherical_harmonics(self.Rs_sh, edge_vec, 'component') / self.num_neighbors**0.5
    edge_len = edge_vec.norm(dim=1)
    edge_weight = self.radial(edge_len)
    # Smooth cosine cutoff: (cos(pi * d / cutoff) + 1) / 2, i.e. 1 at d = 0
    # and 0 at d = cutoff.
    edge_c = (pi * edge_len / self.cutoff).cos().add(1).div(2)

    for conv, act, shortcut in self.layers[:-1]:
        with torch.autograd.profiler.record_function("Layer"):
            if shortcut:
                s = shortcut(h)

            h = conv(h, edge_index, edge_weight, edge_c, edge_sh)  # convolution
            h = act(h)  # gate non linearity

            if shortcut:
                # Residual mix: where the mask m is 1 this gives
                # sqrt(1/2) * (s + h); where m is 0 it gives sqrt(1/2) * s + h.
                m = shortcut.output_mask
                h = 0.5**0.5 * s + (1 + (0.5**0.5 - 1) * m) * h

    with torch.autograd.profiler.record_function("Layer"):
        h = self.layers[-1](h, edge_index, edge_weight, edge_c, edge_sh)

    # Collapse the scalar (l = 0) output channels into one value per node:
    # even-parity channels are added directly, odd-parity channels are
    # squared and halved so the sum stays parity-even. (`s` is reused here.)
    s = 0
    for i, (mul, l, p) in enumerate(self.Rs_out):
        assert mul == 1 and l == 0
        if p == 1:
            s += h[:, i]
        if p == -1:
            s += h[:, i].pow(2).mul(0.5)  # odd^2 = even
    h = s.view(-1, 1)

    # Undo target standardisation when statistics are provided.
    if self.mean is not None and self.std is not None:
        h = h * self.std + self.mean

    # Optional per-atom-type reference offset.
    if self.atomref is not None:
        h = h + self.atomref(z)

    out = scatter(h, batch, dim=0, reduce=self.readout)

    if self.scale is not None:
        out = self.scale * out

    return out
def __call__(self, data):
    """Attach a radius graph with distance edge features when missing.

    If ``data`` already carries a non-``None`` ``edge_attr``, it is returned
    unchanged. Otherwise, a radius graph with cutoff ``self.cutoff`` and no
    self-loops is computed, stored as ``data.edge_index``, and the edge
    lengths are stored as ``data.edge_attr`` with shape (num_edges, 1).

    The connectivity is stored together with the features so that
    ``edge_attr`` always lines up with ``edge_index``. Previously the new
    edges were discarded, and the ``hasattr`` check is always true for PyG
    ``Data``, so the features were never written.
    """
    if getattr(data, "edge_attr", None) is not None:
        return data

    edge_index = radius_graph(data.pos, r=self.cutoff,
                              batch=getattr(data, "batch", None))
    edge_index, _ = remove_self_loops(edge_index)
    row, col = edge_index[0], edge_index[1]
    edge_weight = (data.pos[row] - data.pos[col]).norm(dim=-1)

    data.edge_index = edge_index
    data.edge_attr = edge_weight.reshape(-1, 1)
    return data
def forward(self, atomic_ns, edge_index, coords, batch_node_vec):
    """Create initial node and edge embeddings.

    Node embeddings come from ``self.node_emb`` on the atom charges. Edges
    are rebuilt as a radius graph (``self.cutoff``) over ``coords``, so the
    ``edge_index`` argument is ignored. Interatomic distances are embedded
    with ``self.dist_emb``.

    Returns:
        ``(node_embs, edge_embs, edge_weights)`` where ``edge_weights`` are
        the interatomic distances.
    """
    node_embs = self.node_emb(atomic_ns)
    src, dst = radius_graph(coords, self.cutoff, batch_node_vec)
    edge_weights = (coords[src] - coords[dst]).norm(dim=-1)
    return node_embs, self.dist_emb(edge_weights), edge_weights
def forward(self, data, phase):
    """Graph-level prediction from a single conv + graph-norm layer.

    The incoming ``data.edge_index`` is ignored. Connectivity is rebuilt as
    a radius-4 graph over ``data.pos`` (cast to float). ``phase`` is
    accepted for interface compatibility but unused.

    Returns:
        ``{'A2': prediction}`` with one row per mean-pooled graph.
    """
    batch = data.batch
    graph = radius_graph(data.pos.float(), r=4, batch=batch)
    h = self.lin01(data.x)
    h = self.graph_norm_1(F.relu(self.conv1(h, graph)), batch)
    pooled = global_mean_pool(h, batch)
    return {'A2': self.lin3(self.lin2(pooled))}
def forward(self, data):
    """Predict energy (and optionally forces) with DimeNet-style blocks.

    Args:
        data: Batch exposing ``pos``, ``batch`` and ``atomic_numbers``.

    Returns:
        ``energy`` (summed per graph via ``batch``), or ``(energy, forces)``
        when ``self.regress_forces`` is set. Forces are the negative gradient
        of energy w.r.t. positions.
    """
    pos = data.pos
    if self.regress_forces:
        pos = pos.requires_grad_(True)
    batch = data.batch
    x = self.embedding(data.atomic_numbers.long())

    edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
    j, i = edge_index

    idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
        edge_index, num_nodes=x.size(0))

    # Calculate distances.
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    # Calculate angles from detached positions, so force gradients flow only
    # through the distance terms.
    pos_i = pos[idx_i].detach()
    pos_ji = pos[idx_j].detach() - pos_i
    pos_ki = pos[idx_k].detach() - pos_i
    a = (pos_ji * pos_ki).sum(dim=-1)
    # dim must be explicit: without it torch.cross uses the FIRST dimension
    # of size 3, i.e. the triplet dimension when there are exactly 3 triplets.
    b = torch.cross(pos_ji, pos_ki, dim=1).norm(dim=-1)
    angle = torch.atan2(b, a)

    rbf = self.rbf(dist)
    sbf = self.sbf(dist, angle, idx_kj)

    # Embedding block.
    num_nodes = pos.size(0)
    x = self.emb(x, rbf, i, j)
    P = self.output_blocks[0](x, rbf, i, num_nodes=num_nodes)

    # Interaction blocks. Every output block gets num_nodes so all scatters
    # return the same number of rows; otherwise trailing nodes without
    # incoming edges shrink the result and the sum fails.
    for interaction_block, output_block in zip(self.interaction_blocks,
                                               self.output_blocks[1:]):
        x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
        P = P + output_block(x, rbf, i, num_nodes=num_nodes)

    energy = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)

    if self.regress_forces:
        forces = -1 * (torch.autograd.grad(
            energy,
            pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0])
        return energy, forces
    return energy
def __call__(self, data):
    """Prepend a global node that every original node sends an edge to.

    The new node is inserted at index 0 with feature equal to the mean of
    the existing node features. Original nodes shift to indices 1..N, and
    their radius-graph edges (``self.radius``, no batch vector) are shifted
    accordingly. Edges ``k -> 0`` for k = 1..N are added (one direction
    only). The global node gets a zero position and the original positions
    are shifted by +1.

    Helper tensors are created on the input's device and match the
    positions' dimensionality. Previously they were hard-coded to CPU and
    2-D.
    """
    nodes = data.x
    global_node = nodes.mean(dim=0).unsqueeze(0)
    data.x = torch.cat([global_node, nodes])

    pos = data.pos
    edge_index = radius_graph(pos, r=self.radius)

    num_total = data.x.size(0)
    senders = torch.arange(1, num_total, device=edge_index.device)
    to_global = torch.stack((senders, torch.zeros_like(senders)))
    data.edge_index = torch.cat((to_global, edge_index + 1), dim=1)

    data.pos = torch.cat((pos.new_zeros((1, pos.size(1))), pos + 1))
    return data
def forward(self, data, phase):
    """Graph-level prediction from one conv layer followed by graph norm.

    Connectivity is rebuilt as a radius-4 graph over ``data.pos`` (cast to
    float); the incoming ``data.edge_index`` is not used. ``phase`` is
    accepted for interface compatibility but unused.

    Returns:
        ``{'A2': prediction}`` with one row per mean-pooled graph.
    """
    batch = data.batch
    graph = radius_graph(data.pos.float(), r=4, batch=batch)
    hidden = F.relu(self.conv1(self.lin01(data.x), graph))
    hidden = self.graph_norm_1(hidden, batch)
    pooled = global_mean_pool(hidden, batch)
    return {'A2': self.lin3(self.lin2(pooled))}
def __call__(self, data, sigma=0.01, clip=0.05):
    """Randomly jitter node positions and rebuild the radius graph.

    Each coordinate gets independent Gaussian noise with standard deviation
    ``sigma``, clipped to ``[-clip, clip]``. The noise is drawn with
    NumPy's global RNG.

    Args:
        data: Object with a ``pos`` tensor; modified in place.
        sigma: Noise standard deviation.
        clip: Absolute bound on each noise component.

    Returns:
        ``data`` with jittered ``pos`` (same dtype and device as before) and
        a radius-0.05 ``edge_index``.
    """
    pc = data.pos
    noise = np.clip(sigma * np.random.randn(*pc.shape), -1 * clip, clip)
    # NumPy noise is float64; cast it to pos's dtype/device so positions do
    # not silently become float64 (and CUDA positions keep working).
    data.pos = pc + torch.from_numpy(noise).to(dtype=pc.dtype, device=pc.device)
    # NOTE(review): edges are built from the positions *before* jittering —
    # confirm whether the jittered positions were intended.
    data.edge_index = nn.radius_graph(x=pc, r=0.05)
    return data
def __call__(self, data: Data) -> Data:
    """Build a radius graph over ``data.pos`` that skips NaN positions.

    A node is invalid when any of its coordinates is NaN. Invalid nodes are
    moved far away (10_000 in every coordinate) on a copy of the positions
    before the graph is built. Every edge touching an invalid node is then
    dropped. ``data.pos`` itself is left unchanged and ``edge_attr`` is
    cleared.

    Returns:
        ``data`` with the filtered ``edge_index``.
    """
    data.edge_attr = None
    nans = torch.isnan(data.pos).any(dim=1)
    pos = data.pos.clone()
    pos[nans] = 10_000
    batch = data.batch if "batch" in data else None
    edge_index = radius_graph(
        pos,
        self.r,
        batch,
        self.loop,
        self.max_num_neighbors,
    )
    # O(E) mask lookup instead of comparing every edge against every NaN node.
    drop = nans[edge_index[0]] | nans[edge_index[1]]
    data.edge_index = edge_index[:, ~drop]
    return data