def forward(self, data):
    """Apply the GNN layer repeatedly with confidence-gated accumulation.

    Every iteration runs ``self.gnn_layer_module`` on ``self.next_x(...)``,
    reads a per-graph confidence from ``self.readout_module`` /
    ``self.confidence_module`` and broadcasts it to the nodes through
    ``data.batch``.  Node features are accumulated weighted by that
    confidence until each node's remaining budget drops below ``1e-7`` or
    ``self.layer_num`` iterations have run.

    Args:
        data: object exposing ``x``, ``batch`` and a ``__dict__`` whose other
            entries are forwarded unchanged to every ``Batch``.

    Returns:
        tuple: ``(x, iter_num, residual_confidence)``.  ``iter_num`` is the
        zero-based index of the last executed iteration (``0`` when
        ``self.layer_num == 0``).
    """
    if self.layer_num == 0:
        return data.x, 0, torch.zeros_like(data.x[:, 0:1])

    x, batch = data.x, data.batch
    # Shallow copy so that removing 'x' does not touch the caller's object.
    kwargs = dict(vars(data))
    kwargs.pop('x')

    new_x = x
    # One column per node: budget still available / budget left for the
    # final step / nodes whose budget is already exhausted.
    left_confidence = torch.ones_like(x[:, 0:1])
    residual_confidence = torch.ones_like(x[:, 0:1])
    zero_mask = torch.zeros_like(x[:, 0:1])

    for iter_num in range(self.layer_num):
        data = Batch(x=self.next_x(x, new_x, left_confidence,
                                   self.decreasing_ratio), **kwargs)
        new_x = self.gnn_layer_module(data)
        global_feat = self.readout_module(Batch(x=new_x, **kwargs))
        current_confidence = self.confidence_module(global_feat)[batch]

        # Only nodes that were still active spend budget in this step.
        left_confidence = left_confidence - current_confidence * (1 - zero_mask)
        current_zero_mask = (left_confidence < 1e-7).type(torch.float)
        # The residual keeps shrinking only for nodes that stay active.
        residual_confidence = (residual_confidence
                               - current_confidence * (1 - current_zero_mask))

        # Active nodes are weighted by the current confidence; nodes that ran
        # out of budget in this very step are weighted by their residual.
        step_weight = (current_confidence * (1 - current_zero_mask)
                       + residual_confidence * current_zero_mask * (1 - zero_mask))
        x = x + step_weight * new_x

        zero_mask = current_zero_mask
        # Stop early once every node has exhausted its budget.
        if torch.min(zero_mask).item() > 0.5:
            break

    return x, iter_num, residual_confidence
def forward(self, data):
    """Apply the GNN layer repeatedly until the confidence budget is used up.

    Each iteration runs ``self.gnn_layer_module`` on ``self.next_x(...)``,
    derives a per-graph confidence (broadcast to nodes via ``data.batch``)
    and folds the new features into ``x`` with ``self.update_x``; the
    remaining budget is tracked with ``self.update_confidence``.

    Args:
        data: object exposing ``x``, ``batch`` and a ``__dict__`` whose other
            entries are forwarded unchanged to every ``Batch``.

    Returns:
        tuple: ``(x, iter_num)`` where ``iter_num`` is the zero-based index of
        the last loop iteration entered (``0`` when ``self.layer_num == 0``).
    """
    if self.layer_num == 0:
        return data.x, 0

    x, batch = data.x, data.batch
    # Shallow copy so that removing 'x' does not touch the caller's object.
    kwargs = dict(vars(data))
    kwargs.pop('x')

    new_x = x
    left_confidence = torch.ones_like(x[:, 0:1])
    for iter_num in range(self.layer_num):
        # Written as "not (> eps)" so a NaN budget also stops the loop,
        # exactly like the original if/else.
        if not torch.max(left_confidence).item() > 1e-7:
            break

        data = Batch(x=self.next_x(x, new_x, left_confidence,
                                   self.decreasing_ratio), **kwargs)
        new_x = self.gnn_layer_module(data)
        global_feat = self.readout_module(Batch(x=new_x, **kwargs))
        current_confidence = self.confidence_module(global_feat)[batch]

        # The first iteration accumulates onto zeros instead of the raw input.
        base_x = x if iter_num != 0 else torch.zeros_like(x)
        x = self.update_x(base_x, new_x, left_confidence, current_confidence,
                          self.decreasing_ratio)
        left_confidence = self.update_confidence(
            left_confidence, current_confidence, self.decreasing_ratio)

    return x, iter_num
def update(self, memory):
    """Optimise the policy and the exploration critic on the stored rollout.

    Reads ``rewards``, ``terminals``, ``candidates``, ``states``,
    ``states_next``, ``actions`` and ``logprobs`` from ``memory``, collates
    them into tensors / ``Batch`` objects on ``self.device`` and runs
    ``self.k_epochs`` update steps.

    Args:
        memory: rollout buffer exposing the attributes listed above.
    """
    # Monte Carlo estimate of the discounted return, walking backwards through
    # time and resetting at episode boundaries.  Append + reverse keeps this
    # linear (repeated insert(0, ...) is quadratic).
    returns = []
    discounted_reward = 0
    for reward, terminal in zip(reversed(memory.rewards),
                                reversed(memory.terminals)):
        if terminal:
            discounted_reward = 0
        discounted_reward = reward + (self.gamma * discounted_reward)
        returns.append(discounted_reward)
    returns.reverse()

    # NOTE: the returns are used as-is; no normalisation is applied here.
    rewards = torch.tensor(returns).to(self.device)

    # For every candidate, the index of the state it belongs to.
    batch_idx = []
    for i, cands in enumerate(memory.candidates):
        batch_idx.extend([i] * len(cands))
    batch_idx = torch.LongTensor(batch_idx).to(self.device)

    # One Batch per graph view (a second view is added when use_3d is set).
    num_views = 1 + self.use_3d
    states = [
        Batch.from_data_list([state[i] for state in memory.states]
                             ).to(self.device)
        for i in range(num_views)
    ]
    states_next = [
        Batch.from_data_list(
            [state_next[i] for state_next in memory.states_next]
        ).to(self.device)
        for i in range(num_views)
    ]
    candidates = [
        Batch.from_data_list(
            [item[i] for sublist in memory.candidates for item in sublist]
        ).to(self.device)
        for i in range(num_views)
    ]
    actions = torch.tensor(memory.actions).to(self.device)
    old_logprobs = torch.tensor(memory.logprobs).to(self.device)

    old_values = self.policy.get_value(states)

    # Optimize policy for k epochs:
    logging.info("Optimizing...")
    for i in range(1, self.k_epochs + 1):
        loss, baseline_loss = self.policy.update(states, candidates, actions,
                                                 rewards, old_logprobs,
                                                 old_values, batch_idx)
        rnd_loss = self.explore_critic.update(states_next)
        if (i % 10) == 0:
            logging.info(
                " {:3d}: Actor Loss: {:7.3f}, Critic Loss: {:7.3f}, RND Loss: {:7.3f}"
                .format(i, loss, baseline_loss, rnd_loss))
def set_input(self, data, device):
    """Store the source cloud, the optional target cloud and the match indices.

    Args:
        data: object with ``pos``, ``x``, ``batch``, ``pair_ind`` and
            ``size_pair_ind``; when it also has ``pos_target`` the
            ``*_target`` attributes are packed into ``self.input_target``.
        device: device every stored object is moved to.
    """
    self.input = Batch(pos=data.pos, x=data.x, batch=data.batch).to(device)

    if hasattr(data, "pos_target"):
        target = Batch(pos=data.pos_target,
                       x=data.x_target,
                       batch=data.batch_target)
        self.input_target = target.to(device)

    # The match indices are stored the same way with or without a target.
    self.match = data.pair_ind.to(torch.long).to(device)
    self.size_match = data.size_pair_ind.to(torch.long).to(device)
def to_data_list(data):
    """Split a batched graph into one single-graph ``Batch`` per graph id.

    If the instance itself carries a ``to_data_list`` attribute it is called
    and its result returned.  Otherwise nodes are grouped by ``data.batch``
    and edges are assigned to a graph by the index range of their source
    node.

    NOTE(review): the edge selection assumes that the nodes of each graph
    occupy one contiguous index range -- confirm against the callers.

    Args:
        data: object exposing ``x``, ``y``, ``batch``, ``edge_index`` and
            ``edge_attr``.

    Returns:
        list: one ``Batch`` per graph id, in ascending graph-id order.
    """
    # Only an attribute stored on the instance is detected here; a method
    # defined on the class does not appear in the instance __dict__.
    if 'to_data_list' in data.__dict__:
        return data.to_data_list()

    device = data.x.device
    # Loop-invariant index vectors, built once.
    all_nodes = torch.arange(data.x.size(0), device=device)
    all_edges = torch.arange(data.edge_index.size(1), device=device)
    source_nodes = data.edge_index[0]

    data_list = []
    # Sorted so the output order does not depend on set iteration order.
    for gi in sorted(set(data.batch.tolist())):
        node_indexes = all_nodes[data.batch == gi]
        node_index_max = torch.max(node_indexes)
        node_index_min = torch.min(node_indexes)

        in_graph = ((source_nodes >= node_index_min)
                    & (source_nodes <= node_index_max))
        edge_indexes = all_edges[in_graph]

        data_list.append(Batch(
            x=data.x[node_indexes],
            y=data.y[gi:gi + 1],
            batch=torch.zeros_like(node_indexes),
            # Re-base node ids so every graph starts at 0.
            edge_index=data.edge_index[:, edge_indexes] - node_index_min,
            edge_attr=data.edge_attr[edge_indexes],
        ))

    # Sanity checks: nothing was lost or duplicated by the split.
    assert data.x.size(0) == sum(d.x.size(0) for d in data_list)
    assert all(d.x.size(0) == d.batch.size(0) for d in data_list)
    assert data.edge_index.size(1) == sum(d.edge_index.size(1) for d in data_list)
    assert all(d.edge_index.size(1) == d.edge_attr.size(0) for d in data_list)
    assert all(data.x.size(1) == d.x.size(1) for d in data_list)
    assert all(data.edge_attr.size(1) == d.edge_attr.size(1) for d in data_list)
    return data_list
def forward(self, data):
    """Run the conv stack with optional graclus pooling after each layer.

    A mean-pooled readout is collected after ``conv1`` and after every layer
    in ``self.convs``.  Unless ``self.pooling_type`` is ``'none'``, the graph
    is coarsened with ``graclus`` + ``max_pool`` after each of those layers,
    clustering on either the original edges (``'graclus'``) or on the edges
    returned by ``batched_negative_edges`` (``'complement'``).

    NOTE(review): the source had lost its indentation; pooling is placed
    inside the layer loop, where the reassigned ``x``/``edge_index``/``batch``
    feed the next conv -- confirm against the original file.

    Args:
        data: object exposing ``x``, ``edge_index``, ``batch`` and, when
            ``self.encode_edge`` is set, ``edge_attr``.

    Returns:
        The output of ``self.lin2``.

    Raises:
        ValueError: if ``self.pooling_type`` is not ``'none'``,
            ``'complement'`` or ``'graclus'`` (previously this surfaced as an
            unbound ``cluster`` variable).
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch

    if self.encode_edge:
        x = self.atom_encoder(x)
        x = self.conv1(x, edge_index, data.edge_attr)
    else:
        x = self.conv1(x, edge_index)
    x = F.relu(x)

    xs = [global_mean_pool(x, batch)]
    for conv in self.convs:
        x = F.relu(conv(x, edge_index))
        xs.append(global_mean_pool(x, batch))

        if self.pooling_type == 'none':
            continue

        if self.pooling_type == 'complement':
            complement = batched_negative_edges(edge_index=edge_index,
                                                batch=batch,
                                                force_undirected=True)
            cluster = graclus(complement, num_nodes=x.size(0))
        elif self.pooling_type == 'graclus':
            cluster = graclus(edge_index, num_nodes=x.size(0))
        else:
            raise ValueError(
                "Unsupported pooling_type: {!r}".format(self.pooling_type))

        pooled = max_pool(cluster,
                          Batch(x=x, edge_index=edge_index, batch=batch))
        x, edge_index, batch = pooled.x, pooled.edge_index, pooled.batch

    if not self.no_cat:
        x = self.jump(xs)
    else:
        x = global_mean_pool(x, batch)

    x = F.relu(self.lin1(x))
    x = self.lin2(x)
    return x
def forward(self, data):
    """Apply ``self.gnn_layer_module`` ``self.layer_num`` times in a row.

    All attributes of ``data`` are forwarded to each layer unchanged; only
    ``x`` is replaced by the previous layer's output.

    Args:
        data: object whose ``__dict__`` holds the ``Batch`` fields,
            including ``x``.

    Returns:
        tuple: ``(x, self.layer_num)`` with ``x`` the final node features.
    """
    # Shallow copy so that overwriting 'x' does not touch the caller's object.
    kwargs = dict(vars(data))
    for _ in range(self.layer_num):
        kwargs['x'] = self.gnn_layer_module(Batch(**kwargs))
        # Sanity check: fail at the layer that first produces a NaN.
        assert not torch.isnan(kwargs['x']).any()
    return kwargs['x'], self.layer_num
def forward(self, data):
    """Apply every module in ``self.layers`` in order.

    All attributes of ``data`` are forwarded to each layer unchanged; only
    ``x`` is replaced by the previous layer's output.

    Args:
        data: object whose ``__dict__`` holds the ``Batch`` fields,
            including ``x``.

    Returns:
        tuple: ``(x, len(self.layers))`` with ``x`` the final node features.
    """
    # Shallow copy so that overwriting 'x' does not touch the caller's object.
    kwargs = dict(vars(data))
    for layer in self.layers:
        kwargs['x'] = layer(Batch(**kwargs))
        # Sanity check: fail at the layer that first produces a NaN.
        assert not torch.isnan(kwargs['x']).any()
    return kwargs['x'], len(self.layers)
def convert_data_to_batch(x):
    """Wrap each element of ``x`` as ``Data(pos=element)`` and collate them.

    Args:
        x: iterable whose elements become the ``pos`` of one ``Data`` each.

    Returns:
        The ``Batch`` produced by ``Batch.from_data_list``.
    """
    samples = [Data(pos=points) for points in x]
    return Batch.from_data_list(samples)
def x_pos_batch_to_pair_biggraph_pair(cloud_s_all, cloud_t_all, lss, lst):
    """Pack padded source/target clouds into two flat ``Batch`` graphs.

    For sample ``i`` only the first ``lss[i]`` / ``lst[i]`` rows of the
    source / target cloud are kept.  The kept rows of all samples are
    concatenated and tagged with their sample index.

    Args:
        cloud_s_all: iterable of 2-D source clouds (rows are points).
        cloud_t_all: iterable of 2-D target clouds.
        lss: tensor with the number of valid source rows per sample.
        lst: tensor with the number of valid target rows per sample.

    Returns:
        tuple: ``(graph_s, graph_t)`` where ``x`` is column 2, ``pos`` is
        columns 0..2 and ``batch`` is a 1-D long tensor of sample indices.
    """
    x_pos_s_all, x_pos_t_all, batch_s, batch_t = [], [], [], []
    for i, (ls, lt, cloud_s, cloud_t) in enumerate(
            zip(lss, lst, cloud_s_all, cloud_t_all)):
        ls, lt = int(ls), int(lt)
        x_pos_s_all.append(cloud_s[:ls, :])
        x_pos_t_all.append(cloud_t[:lt, :])
        # Build the ids 1-D and on the target device directly.  The previous
        # unsqueeze + bare squeeze() collapsed the batch vector to a 0-dim
        # tensor whenever exactly one point was kept in total.
        batch_s.append(torch.full((ls,), i, dtype=torch.long,
                                  device=lss.device))
        batch_t.append(torch.full((lt,), i, dtype=torch.long,
                                  device=lst.device))

    x_pos_s_all = torch.cat(x_pos_s_all, dim=0)
    x_pos_t_all = torch.cat(x_pos_t_all, dim=0)
    batch_s = torch.cat(batch_s, dim=0)
    batch_t = torch.cat(batch_t, dim=0)

    graph_s = Batch(x=x_pos_s_all[:, 2:3], pos=x_pos_s_all[:, :3],
                    batch=batch_s)
    graph_t = Batch(x=x_pos_t_all[:, 2:3], pos=x_pos_t_all[:, :3],
                    batch=batch_t)
    return graph_s, graph_t
def _prepare_batch(self, batch):
    """Create batch data for the MEGNet model.

    Note
    ----
    Ideally only the ``default_generator`` method would be overridden, but
    ``_prepare_batch`` of TorchModel only supports non-graph data types, so
    it is overridden here as well.  This should be fixed some time in the
    future.
    """
    try:
        from torch_geometric.data import Batch
    except ModuleNotFoundError:
        raise ImportError("This module requires PyTorch Geometric")

    graphs, labels, weights = batch

    # default_generator returns the array of dc.feat.GraphData objects nested
    # inside a list, hence the [0]; each one is converted to a PyG graph
    # before collating.
    pyg_graphs = [graph.to_pyg_graph() for graph in graphs[0]]
    pyg_batch = Batch.from_data_list(pyg_graphs)

    # Labels and weights go through the parent implementation, with an empty
    # input list in place of the graphs.
    prepared = super(MEGNetModel, self)._prepare_batch(([], labels, weights))
    _, labels, weights = prepared
    return pyg_batch, labels, weights
def tg_transform(args, X):
    """Turn a dense batch of hits into one radius-graph ``Batch``.

    ``X[:, :, :2]`` is used as the 2-D position and ``X[:, :, 2]`` as the
    node feature.  Two hits of the same sample are connected when the norm
    of their position difference is below ``args.cutoff`` (self-loops are
    removed).

    Args:
        args: namespace providing ``num_hits``, ``cutoff`` and ``device``.
        X: tensor indexed as ``[sample, hit, feature]``.

    Returns:
        ``Batch`` with ``x`` (feature + 0.5), ``pos`` (28 * pos + 14),
        ``edge_index``, ``edge_attr`` (rescaled position difference),
        ``batch`` and ``y=None``.
    """
    batch_size = X.size(0)
    num_hits = args.num_hits
    pos = X[:, :, :2]

    # Pairwise differences: entry i * num_hits + j of x1 is pos[i] and of x2
    # is pos[j], so norms[b, i, j] = |pos[j] - pos[i] + 1e-12|.
    x1 = pos.repeat(1, 1, num_hits).reshape(batch_size, num_hits * num_hits, 2)
    x2 = pos.repeat(1, num_hits, 1)
    diff_norms = torch.norm(x2 - x1 + 1e-12, dim=2)
    norms = diff_norms.reshape(batch_size, num_hits, num_hits)

    # Rows are (sample, i, j) triples of hits closer than the cutoff.
    neighborhood = torch.nonzero(norms < args.cutoff, as_tuple=False)
    # remove self-loops
    neighborhood = neighborhood[neighborhood[:, 1] != neighborhood[:, 2]]

    # Offset per-sample hit ids into ids of the flattened node list.
    sample_offset = (neighborhood[:, 0] * num_hits).view(-1, 1)
    edge_index = (neighborhood[:, 1:] + sample_offset).transpose(0, 1)

    x = X[:, :, 2].reshape(batch_size * num_hits, 1) + 0.5
    pos = 28 * pos.reshape(batch_size * num_hits, 2) + 14

    row, col = edge_index
    edge_attr = (pos[col] - pos[row]) / (2 * 28 * args.cutoff) + 0.5

    # Sample id of every node: [0]*num_hits + [1]*num_hits + ...
    batch = torch.arange(batch_size,
                         device=args.device).repeat_interleave(num_hits)

    return Batch(batch=batch, x=x, edge_index=edge_index.contiguous(),
                 edge_attr=edge_attr, y=None, pos=pos)
def from_data_list_token(data_list, follow_batch=None):
    """Collate ``data_list`` into a ``Batch`` without shifting negative ids.

    This is close to a copy of the ``from_data_list`` of the PyTorch
    Geometric ``Batch`` object, with the difference that entries that are
    negative are not incremented.

    The input ``Data`` objects are left untouched: the shift is applied to a
    copy of each tensor (the previous in-place update modified the caller's
    tensors, so collating the same list twice shifted the ids twice).

    Args:
        data_list: non-empty list of data objects to collate.
        follow_batch: optional keys for which an extra ``"<key>_batch"``
            assignment vector is created.

    Returns:
        The collated, contiguous ``Batch``.

    Raises:
        ValueError: if an attribute is neither a tensor, an int nor a float.
    """
    follow_batch = [] if follow_batch is None else follow_batch

    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                # Work on a copy so the caller's tensor is not modified, and
                # shift only the non-negative entries.
                item = item.clone()
                mask = item >= 0
                item[mask] = item[mask] + cumsum[key]

            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch["{}_batch".format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # Decided from the last element only, as in the original implementation.
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key],
                                   dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def sample_batch_pyg(data, sample_config):
    """
    Perturb the structure and node attributes.

    Parameters
    ----------
    data: torch_geometric.data.Batch
        Batch providing the attributes ``x``, the edge indices and the
        batch-ID vector
    sample_config: dict
        Configuration specifying the sampling probabilities; a missing
        ``pf_plus_adj``, ``pf_plus_att``, ``pf_minus_adj`` or
        ``pf_minus_att`` entry defaults to 0

    Returns
    -------
    per_data: torch_geometric.data.Batch
        Batch holding the perturbed attributes and edges together with the
        original batch-ID vector
    """
    plus_adj = sample_config.get('pf_plus_adj', 0)
    plus_att = sample_config.get('pf_plus_att', 0)
    minus_adj = sample_config.get('pf_minus_adj', 0)
    minus_att = sample_config.get('pf_minus_att', 0)

    perturbed_x = binary_perturb(data.x, minus_att, plus_att)

    # Number of nodes per graph, counted from the batch-ID vector.
    nodes_per_graph = torch.bincount(data.batch)
    perturbed_edges = sparse_perturb_adj_batch(data_idx=data.edge_index,
                                               nnodes=nodes_per_graph,
                                               pf_minus=minus_adj,
                                               pf_plus=plus_adj,
                                               undirected=True)

    return Batch(batch=data.batch, x=perturbed_x, edge_index=perturbed_edges)
def avg_pool(cluster, data, transform=None):
    r"""Pools and coarsens a graph given by the
    :class:`torch_geometric.data.Data` object according to the clustering
    defined in :attr:`cluster`.
    Final node features are defined by the *average* features of all nodes
    within the same cluster.
    See :meth:`torch_geometric.nn.pool.max_pool` for more details.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific
            cluster.
        data (Data): Graph data object.
        transform (callable, optional): A function/transform that takes in
            the coarsened and pooled :obj:`torch_geometric.data.Data` object
            and returns a transformed version.
            (default: :obj:`None`)

    :rtype: :class:`torch_geometric.data.Data`
    """
    cluster, perm = consecutive_cluster(cluster)

    # Optional attributes stay None when the input does not provide them.
    pooled_x = None
    if data.x is not None:
        pooled_x = _avg_pool_x(cluster, data.x)

    pooled_index, pooled_attr = pool_edge(cluster, data.edge_index,
                                          data.edge_attr)

    pooled_batch = None
    if data.batch is not None:
        pooled_batch = pool_batch(perm, data.batch)

    pooled_pos = None
    if data.pos is not None:
        pooled_pos = pool_pos(cluster, data.pos)

    coarse = Batch(batch=pooled_batch, x=pooled_x, edge_index=pooled_index,
                   edge_attr=pooled_attr, pos=pooled_pos)

    if transform is None:
        return coarse
    return transform(coarse)
def forward(self, data):
    """Compute a weighted, normalised sum of subgraph outputs per graph.

    Subgraphs are sampled with ``subgraph_loader``, embedded by
    ``self.gnn_layer`` in chunks of 500, passed through
    ``self.output_layer``, multiplied by their sampling weights and summed
    per graph with ``global_add_pool``.

    ``k``, ``super_node_size``, ``num_tours`` and ``num_cpus`` are resolved
    from an enclosing scope that is not visible in this block.

    Args:
        data: indexable collection of graphs understood by
            ``subgraph_loader`` / ``get_subgraph``.

    Returns:
        Pooled weighted outputs divided by the pooled weights.
    """
    # Follow the device the parameters actually live on; a bare .cuda()
    # always targets the current CUDA device, which is wrong for models
    # placed on another GPU.
    device = next(self.parameters()).device

    subgraph_data = subgraph_loader(data, k, super_node_size, num_tours,
                                    num_cpus)
    subgraphs = [
        get_subgraph(data[subgraph_data.batch[i].item()], nodes.squeeze())
        for i, nodes in enumerate(subgraph_data.subgraphs)
    ]

    chunk_size = 500  # bounds the size of each collated Batch
    subgraphs_lst = []
    for start in range(0, len(subgraphs), chunk_size):
        # Slicing already clamps at the end of the list.
        chunk = Batch.from_data_list(subgraphs[start:start + chunk_size])
        subgraphs_lst.append(
            self.gnn_layer(chunk.x.to(device),
                           chunk.edge_index.to(device),
                           chunk.batch.to(device)))
    subgraphs = torch.cat(subgraphs_lst, dim=0)
    subgraphs = self.output_layer(subgraphs)

    weights = subgraph_data.weights.to(device)
    batch = subgraph_data.batch.to(device)

    subgraphs = subgraphs * weights
    norm = global_add_pool(weights, batch)
    energy = global_add_pool(subgraphs, batch)
    return energy / norm
def forward(self, data):
    """Sample centres, convolve at every scale and concatenate the results.

    Sampled indices and per-scale edges are read from ``data`` when
    ``self._precompute_multi_scale`` is set (attributes ``idx_<index>`` and
    ``edge_index_<index>_<scale>``); otherwise they come from
    ``self.sampler`` and ``self.neighbour_finder``.

    Args:
        data: object exposing ``x``, ``pos`` and ``batch``.

    Returns:
        ``Batch`` carrying ``idx``, the concatenated ``x`` and the sampled
        ``pos`` / ``batch``; ``copy_from_to(data, ...)`` is applied to it
        before returning.
    """
    precomputed = self._precompute_multi_scale
    x, pos, batch = data.x, data.pos, data.batch

    if precomputed:
        idx = getattr(data, "idx_{}".format(self._index), None)
    else:
        idx = self.sampler(pos, batch)

    out = Batch()
    out.idx = idx

    per_scale = []
    for scale_idx in range(self.neighbour_finder.num_scales):
        if precomputed:
            edge_index = getattr(
                data, "edge_index_{}_{}".format(self._index, scale_idx), None)
        else:
            row, col = self.neighbour_finder(
                pos,
                pos[idx],
                batch_x=batch,
                batch_y=batch[idx],
                scale_idx=scale_idx,
            )
            # The finder's (row, col) pair is stacked in swapped order.
            edge_index = torch.stack([col, row], dim=0)

        per_scale.append(self.conv(x, (pos, pos[idx]), edge_index, batch))

    out.x = torch.cat(per_scale, -1)
    out.pos = pos[idx]
    out.batch = batch[idx]
    copy_from_to(data, out)
    return out
def forward(self, data):
    """Run the convolution towards the skip points and merge skip features.

    Args:
        data: pair ``(data, data_skip)``; both expose ``x``, ``pos`` and
            ``batch``.

    Returns:
        ``Batch`` whose ``x`` is the conv output (concatenated with the skip
        features when available and ``self._skip`` is set, then passed
        through ``self.nn`` if that attribute exists);
        ``copy_from_to(data_skip, ...)`` is applied to it before returning.
    """
    data, data_skip = data
    x, pos, batch = data.x, data.pos, data.batch
    x_skip, pos_skip, batch_skip = data_skip.x, data_skip.pos, data_skip.batch

    if self.neighbour_finder is None:
        edge_index = None
    elif self._precompute_multi_scale:
        # TODO For now, it uses the one calculated during down steps
        stored = getattr(data_skip, "edge_index_{}".format(self._index), None)
        col, row = stored
        # The stored edges are flipped before use.
        edge_index = torch.stack([row, col], dim=0)
    else:
        row, col = self.neighbour_finder(pos, pos_skip, batch, batch_skip)
        edge_index = torch.stack([col, row], dim=0)

    x = self.conv(x, pos, pos_skip, batch, batch_skip, edge_index)

    if x_skip is not None and self._skip:
        x = torch.cat([x, x_skip], dim=1)

    out = Batch()
    out.x = self.nn(x) if hasattr(self, "nn") else x
    copy_from_to(data_skip, out)
    return out
def buildGraph(feat, label):
    """Build one fully connected graph per sample and collate them.

    Every sample becomes a graph with ``feat.shape[1]`` nodes, a directed
    edge between each ordered pair of distinct nodes and the Euclidean
    distance between the two node feature vectors as edge attribute.

    Args:
        feat: tensor whose first two dimensions are (sample, node); any
            remaining dimensions are flattened into the node feature.
        label: per-sample labels, indexable by sample.

    Returns:
        The ``Batch`` collated from the per-sample ``Data`` objects.
    """
    B = feat.shape[0]
    NoOfNodes = feat.shape[1]

    # All ordered pairs (i, j) with i != j, in row-major order.
    pairs = list(itertools.permutations(range(NoOfNodes), 2))
    edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t()

    # Self-pairs are removed by position.  Filtering on "distance > 0" also
    # dropped edges between distinct nodes with identical features, leaving
    # edge_attr shorter than edge_index.
    off_diagonal = ~np.eye(NoOfNodes, dtype=bool)

    listofData = []
    for i in range(B):
        feat_arr = feat[i].detach().cpu().numpy().reshape(NoOfNodes, -1)
        # distances[a, b] = ||feat_arr[a] - feat_arr[b]||; boolean indexing
        # flattens row-major, matching the edge order above.
        distances = np.linalg.norm(
            feat_arr[:, None, :] - feat_arr[None, :, :], axis=-1)
        edge_attr = torch.tensor(distances[off_diagonal],
                                 dtype=torch.float32).view(-1)
        listofData.append(
            Data(x=torch.tensor(feat_arr, dtype=torch.float32),
                 edge_index=edge_index,
                 edge_attr=edge_attr,
                 y=label[i].view(-1)))

    return Batch.from_data_list(listofData)
def forward(self, data, *args, **kwargs):
    """
    Parameters:
    -----------
    data
        A SparseTensor that contains the data itself and its metadata
        information. Should contain
            F -- Features [N, C]
            coords -- Coords [N, 4]

    Returns
    --------
    data:
        - x [1, output_nc]
    """
    self._set_input(data)

    encoded = self.input
    for down_module in self.down_modules:
        encoded = down_module(encoded)

    # Column 0 of the coordinates is used as the batch index.
    batch_index = encoded.C[:, 0].long().to(encoded.F.device)
    out = Batch(x=encoded.F, batch=batch_index)

    inner = self.inner_modules[0]
    if not isinstance(inner, Identity):
        out = inner(out)

    if self.has_mlp_head:
        out.x = self.mlp(out.x)
    return out
def mols_to_pyg_batch(mols, idm=False, ratio=2., device=None):
    """Convert one or more molecules into collated PyG batches.

    Args:
        mols: a single molecule or a list of molecules.
        idm: forwarded to ``mol_to_pyg_graph``; when true, the second graph
            returned for every molecule is collated as well.
        ratio: forwarded to ``mol_to_pyg_graph``.
        device: optional device the batches are moved to.

    Returns:
        list: ``[g1, g2]`` where ``g2`` is ``None`` unless ``idm`` is true.
    """
    if not isinstance(mols, list):
        mols = [mols]

    graphs = [mol_to_pyg_graph(mol, idm, ratio) for mol in mols]

    g1 = Batch.from_data_list([graph[0] for graph in graphs])
    if device is not None:
        g1 = g1.to(device)

    g2 = None
    if idm:
        # Moved only when a device was requested, exactly like g1 (the
        # previous code also called .to(device) unconditionally, i.e. with
        # None).
        g2 = Batch.from_data_list([graph[1] for graph in graphs])
        if device is not None:
            g2 = g2.to(device)

    return [g1, g2]
def forward(self, data):
    """Convolve over neighbourhoods after appending one shadow point.

    One extra row filled with ``self.shadow_features_fill`` /
    ``self.shadow_points_fill_`` is appended to the features / positions
    before the neighbour indices are used to gather them.

    Args:
        data: object exposing ``x``, ``pos`` and ``batch``.

    Returns:
        ``Batch`` with the conv output as ``x`` and ``pos`` / ``batch``
        taken at the sampled indices; ``copy_from_to(data, ...)`` is applied
        to it before returning.
    """
    x, pos, batch = data.x, data.pos, data.batch

    sampled = self.sampler(pos=pos, x=x, batch=batch)
    neighbours, _ = self.neighbour_finder(pos, pos, batch_x=batch,
                                          batch_y=batch)

    # Single filler row appended to both features and positions.
    filler_x = torch.full((1, ) + x.shape[1:],
                          self.shadow_features_fill).to(x.device)
    filler_pos = torch.full((1, ) + pos.shape[1:],
                            self.shadow_points_fill_).to(x.device)
    x = torch.cat([x, filler_x], dim=0)
    pos = torch.cat([pos, filler_pos], dim=0)

    x_neighbour = x[neighbours]
    # Centre each neighbourhood on its own (non-shadow) point.
    centres = pos[:-1].unsqueeze(1)
    pos_centered_neighbour = pos[neighbours] - centres

    out = Batch()
    out.x = self.conv(x, pos, x_neighbour, pos_centered_neighbour,
                      neighbours, sampled)
    out.pos = pos[sampled]
    out.batch = batch[sampled]
    copy_from_to(data, out)
    return out
def test_graphnet_for_graphs_in_batch():
    """GraphNetwork keeps node/edge/global feature shapes on a graph batch."""
    try:
        from torch_geometric.data import Batch
    except ModuleNotFoundError:
        raise ImportError("Tests require pytorch geometric to be installed")

    n_node_features = 3
    n_edge_features = 4
    n_global_features = 5

    generator = FakeGraphGenerator(min_nodes=8,
                                   max_nodes=12,
                                   n_node_features=n_node_features,
                                   avg_degree=10,
                                   n_edge_features=n_edge_features,
                                   n_classes=2,
                                   task='graph',
                                   z=n_global_features)
    graphs = generator.sample(n_graphs=10)

    graphnet = GraphNetwork(n_node_features, n_edge_features,
                            n_global_features)

    pyg_graphs = [graph.to_pyg_graph() for graph in graphs.X]
    graph_batch = Batch.from_data_list(pyg_graphs)

    outputs = graphnet(graph_batch.x, graph_batch.edge_index,
                       graph_batch.edge_attr, graph_batch.z,
                       graph_batch.batch)
    new_node_features, new_edge_features, new_global_features = outputs

    # Each output must have the shape of the matching input.
    assert graph_batch.x.size() == new_node_features.size()
    assert graph_batch.edge_attr.size() == new_edge_features.size()
    assert graph_batch.z.size() == new_global_features.size()
def test_single_voxel_grid():
    """voxel_grid gives one cluster per graph; avg_pool accepts the result."""
    pos = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
    edge_index = torch.tensor([[0, 0, 3], [1, 2, 4]])
    batch = torch.tensor([0, 0, 0, 1, 1])
    x = torch.randn(5, 16)

    # With a batch vector every graph falls into its own single voxel.
    batched_cluster = voxel_grid(pos, size=5, batch=batch)
    assert batched_cluster.tolist() == [0, 0, 0, 1, 1]
    batched_data = Batch(x=x, edge_index=edge_index, pos=pos, batch=batch)
    batched_data = avg_pool(batched_cluster, batched_data)

    # Without a batch vector all points share one voxel.
    plain_cluster = voxel_grid(pos, size=5)
    assert plain_cluster.tolist() == [0, 0, 0, 0, 0]
    plain_data = Batch(x=x, edge_index=edge_index, pos=pos)
    plain_data = avg_pool(plain_cluster, plain_data)
def get_final_reward(state, env, surrogate_model, device):
    """Return the negated surrogate-model prediction for ``state``.

    Args:
        state: molecule passed to ``mol_to_pyg_graph``.
        env: not used by this function.
        surrogate_model: callable invoked as ``surrogate_model(graph, None)``.
        device: device the collated graph is moved to.

    Returns:
        float: the predicted score multiplied by -1.
    """
    graph = Batch.from_data_list([mol_to_pyg_graph(state)]).to(device)

    # Inference only: no gradients are recorded.
    with torch.no_grad():
        pred_docking_score = surrogate_model(graph, None)

    return pred_docking_score.item() * -1
def tg_transform(args, X):
    """Turn a dense batch of hits into one radius-graph ``Batch``.

    ``X[:, :, :2]`` is used as the 2-D position and ``X[:, :, 2]`` as the
    node feature.  Two hits of the same sample are connected when the norm
    of their position difference is below ``args.cutoff`` (self-loops are
    removed).

    Args:
        args: namespace providing ``num_hits``, ``cutoff`` and ``device``.
        X: tensor indexed as ``[sample, hit, feature]``.

    Returns:
        ``Batch`` with ``x`` (feature + 0.5), ``pos`` (28 * pos + 14),
        ``edge_index``, ``edge_attr`` (rescaled position difference),
        ``batch`` and ``y=None``.
    """
    batch_size = X.size(0)
    num_hits = args.num_hits
    pos = X[:, :, :2]

    # Pairwise differences: entry i * num_hits + j of x1 is pos[i] and of x2
    # is pos[j], so norms[b, i, j] = |pos[j] - pos[i] + 1e-12|.
    x1 = pos.repeat(1, 1, num_hits).reshape(batch_size, num_hits * num_hits, 2)
    x2 = pos.repeat(1, num_hits, 1)
    diff_norms = torch.norm(x2 - x1 + 1e-12, dim=2)
    norms = diff_norms.reshape(batch_size, num_hits, num_hits)

    # Rows are (sample, i, j) triples of hits closer than the cutoff.
    neighborhood = torch.nonzero(norms < args.cutoff, as_tuple=False)
    # remove self-loops
    neighborhood = neighborhood[neighborhood[:, 1] != neighborhood[:, 2]]

    # Offset per-sample hit ids into ids of the flattened node list.
    sample_offset = (neighborhood[:, 0] * num_hits).view(-1, 1)
    edge_index = (neighborhood[:, 1:] + sample_offset).transpose(0, 1)

    x = X[:, :, 2].reshape(batch_size * num_hits, 1) + 0.5
    pos = 28 * pos.reshape(batch_size * num_hits, 2) + 14

    # Edge attributes are normalised with the global cutoff.
    row, col = edge_index
    edge_attr = (pos[col] - pos[row]) / (2 * 28 * args.cutoff) + 0.5

    # Sample id of every node: [0]*num_hits + [1]*num_hits + ...
    batch = torch.arange(batch_size,
                         device=args.device).repeat_interleave(num_hits)

    return Batch(batch=batch, x=x, edge_index=edge_index.contiguous(),
                 edge_attr=edge_attr, y=None, pos=pos)
def embedding(self, subgraphs):
    """Embed ``subgraphs`` with ``self.gnn_layer`` without tracking gradients.

    The graphs are collated and processed in chunks of 500 and the chunk
    outputs are concatenated along dimension 0.

    Args:
        subgraphs: non-empty list of graph data objects.

    Returns:
        Tensor with the concatenated ``self.gnn_layer`` outputs.
    """
    # Follow the device the parameters actually live on; a bare .cuda()
    # always targets the current CUDA device, which is wrong for models
    # placed on another GPU.
    device = next(self.parameters()).device
    chunk_size = 500  # bounds the size of each collated Batch

    with torch.no_grad():
        subgraphs_lst = []
        for start in range(0, len(subgraphs), chunk_size):
            # Slicing already clamps at the end of the list.
            chunk = Batch.from_data_list(subgraphs[start:start + chunk_size])
            subgraphs_lst.append(
                self.gnn_layer(chunk.x.to(device),
                               chunk.edge_index.to(device),
                               chunk.batch.to(device)))
        subgraphs = torch.cat(subgraphs_lst, dim=0)
    return subgraphs
def forward(self, data, **kwargs):
    """Pool all points of every graph into a single feature row.

    Features and positions are concatenated, passed through ``self.nn`` and
    reduced with ``self.pool`` using ``data.batch``.

    Args:
        data: object exposing ``x``, ``pos`` and ``batch``.
        **kwargs: accepted but not used.

    Returns:
        ``Batch`` with the pooled ``x``, an all-zero ``pos`` with 3 columns
        and ``batch = arange(number of pooled rows)``;
        ``copy_from_to(data, ...)`` is applied to it before returning.
    """
    x, pos, batch = data.x, data.pos, data.batch

    features = torch.cat([x, pos], dim=1)
    pooled = self.pool(self.nn(features), batch)
    num_rows = pooled.size(0)

    out = Batch()
    out.x = pooled
    out.pos = pos.new_zeros((num_rows, 3))
    out.batch = torch.arange(num_rows, device=batch.device)
    copy_from_to(data, out)
    return out
def convert_to_batch(args, data, batch_size):
    """Re-wrap ``data`` as a ``Batch`` with a freshly built batch vector.

    Every sample is taken to own ``args.num_hits`` consecutive nodes, so the
    batch vector is ``[0]*num_hits + [1]*num_hits + ...``.

    Args:
        args: namespace providing ``num_hits`` and ``device``.
        data: object exposing ``x``, ``pos``, ``edge_index`` and
            ``edge_attr``.
        batch_size: number of samples contained in ``data``.

    Returns:
        ``Batch`` sharing the tensors of ``data`` plus the new ``batch``.
    """
    # Built directly on the target device; equivalent to the previous
    # scatter-ones + cumsum construction without the CPU index tensor.
    batch = torch.arange(batch_size,
                         device=args.device).repeat_interleave(args.num_hits)
    return Batch(batch=batch, x=data.x, pos=data.pos,
                 edge_index=data.edge_index, edge_attr=data.edge_attr)
def forward(self, data, output_node_feat_flag=False,
            output_layer_num_flag=False,
            output_residual_confidence_flag=False):
    """Embed the input, run all GNN modules, read out and apply the head.

    Args:
        data: object whose ``__dict__`` holds the ``Batch`` fields,
            including ``x``; the raw ``x`` is forwarded to every module as
            ``input_x``.
        output_node_feat_flag: also return the node features produced by the
            last GNN module.
        output_layer_num_flag: also return the summed layer count reported
            by the GNN modules.
        output_residual_confidence_flag: also return the summed residual
            confidence of the modules whose class name contains ``'ACT'``;
            ``None`` when no such module ran (previously this case raised
            ``UnboundLocalError``).

    Returns:
        tuple: ``(out,)`` followed by the requested extras, in the order of
        the flags above.
    """
    # Shallow copy so that editing the dict does not touch the caller's
    # object.
    kwargs = dict(vars(data))
    x = kwargs.pop('x')
    kwargs['input_x'] = x
    x = self.embedding_module(x)

    layer_num = 0
    x_list = []
    residual_confidence_list = []
    for gnn_module in self.gnn_module_list:
        # Modules with 'ACT' in their class name return an extra residual
        # confidence.
        if 'ACT' in gnn_module.__class__.__name__:
            x, cur_layer_num, cur_residual_confidence = gnn_module(
                Batch(x=x, **kwargs))
            residual_confidence_list.append(cur_residual_confidence)
        else:
            x, cur_layer_num = gnn_module(Batch(x=x, **kwargs))
        layer_num += cur_layer_num
        x_list.append(x)

    residual_confidence = None
    if residual_confidence_list:
        residual_confidence = torch.sum(
            torch.stack(residual_confidence_list, dim=0), dim=0)

    # To avoid information-leak between nodes, we perform pointwise
    # head-module for the physical simulation task
    if self.pointwise_head_layer_flag:
        x_list = [self.head_module(feat) for feat in x_list]
    global_feat = self.readout([Batch(x=feat, **kwargs) for feat in x_list])
    if not self.pointwise_head_layer_flag:
        out = self.head_module(global_feat)
    else:
        out = global_feat

    output = (out, )
    if output_node_feat_flag:
        output = output + (x, )
    if output_layer_num_flag:
        output = output + (layer_num, )
    if output_residual_confidence_flag:
        output = output + (residual_confidence, )
    return output