def forward(self, inputs: ElementsToSummaryRepresentationInput) -> torch.Tensor:
    """Pool per-element embeddings into one summary row per sample.

    Each row of ``inputs.element_embeddings`` is reduced into the output row
    selected by ``inputs.element_to_sample_map``, using the reduction stored in
    ``self.__summarization_type``. The output always has
    ``inputs.num_samples`` rows, so samples that own no elements still get one.
    """
    embeddings = inputs.element_embeddings
    sample_index = inputs.element_to_sample_map
    pooled = scatter(
        src=embeddings,
        index=sample_index,
        dim=0,
        dim_size=inputs.num_samples,
        reduce=self.__summarization_type,
    )
    return pooled
def find_rate(edge_index):
    """Return the ratio of the smallest to the largest degree in ``edge_index``.

    Degrees are counted over ``edge_index[1]`` (the target row in the usual PyG
    convention). If some node index below the maximum has degree 0, the
    numerator is replaced by 1 so the rate stays strictly positive.
    """
    num_edges = edge_index.size(1)
    ones = torch.ones(num_edges).to(edge_index.device)
    degree_hist = scatter(ones, edge_index[1], reduce='sum')
    lowest = degree_hist.min()
    highest = degree_hist.max()
    numerator = 1 if lowest == 0 else lowest
    return numerator / highest
def aggregate(self, inputs: Tensor, index: Tensor) -> Tensor:
    """Reduce ``inputs`` into the groups named by ``index`` along dim 0.

    The reduction is ``self.aggr``. The number of output rows is inferred as
    ``index.max() + 1``, so groups above the largest index that appears do not
    get a row.

    Returns:
        Tensor with one row per inferred group. The result has zero rows when
        ``index`` is empty.
    """
    # Step 4: Sum by vertex or by edge.
    # ``index.max()`` raises on an empty tensor, so fall back to zero groups.
    dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
    return scatter(inputs, index, dim=0, dim_size=dim_size, reduce=self.aggr)
def forward(self, z, pos, batch=None):
    """Predict one value per graph from atom types ``z`` and positions ``pos``.

    Pipeline:
      1. Embed atoms.
      2. Build a radius graph with ``self.cutoff``.
      3. Expand the edge distances.
      4. Run ``self.num_interactions`` rounds of edge update, message passing
         and state transition, added residually to the atom state.
      5. Apply two dense layers.
      6. Read out per graph with ``self.readout``.

    In dipole mode the atom outputs are weighted by their offset from the
    mass-weighted center of their graph, and the norm of the pooled vector is
    returned. Otherwise optional mean/std de-normalization and ``atomref``
    offsets are applied before readout. ``self.scale`` multiplies the result
    when set.
    """
    assert z.dim() == 1 and z.dtype == torch.long
    if batch is None:
        batch = torch.zeros_like(z)

    h = self.embedding(z)
    edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
    src, dst = edge_index
    dist = (pos[src] - pos[dst]).norm(dim=-1)
    e = self.distance_expansion(dist)

    # Snapshot of the initial embedding, consumed by hypernetwork transitions.
    h_init = h.clone()
    for step in range(self.num_interactions):
        e = self.edge_updates[step](h, edge_index, e)
        msg = self.msg_passes[step](h, edge_index, e)
        transition = self.state_transitions[step]
        if self.hypernet_update:
            update = transition(h_init, h, msg)
        else:
            update = transition(msg)
        h = h + update

    h = self.fc2(self.act(self.fc1(h)))

    if self.dipole:
        # Mass-weighted center of each graph.
        mass = self.atomic_mass[z].view(-1, 1)
        center = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        h = h * (pos - center[batch])
    else:
        if self.mean is not None and self.std is not None:
            h = h * self.std + self.mean
        if self.atomref is not None:
            h = h + self.atomref(z)

    out = scatter(h, batch, dim=0, reduce=self.readout)
    if self.dipole:
        out = torch.norm(out, dim=-1, keepdim=True)
    if self.scale is not None:
        out = self.scale * out
    return out
def correctness(dataset):
    """Check that ``scatter``, ``segment_coo`` and ``segment_csr`` agree.

    The sparse matrix is loaded from ``f'{name}.mat'`` (``Problem`` struct,
    field 2) and converted to CSR. For each feature width in the module-level
    ``sizes``, random values are reduced per row with every reduction in
    ``('add', 'mean', 'min', 'max')``, and the three implementations are
    compared with ``atol=1e-4``.

    Out-of-memory ``RuntimeError``s are skipped after clearing the CUDA cache.
    Any other ``RuntimeError`` propagates with its original traceback.

    Args:
        dataset: ``(group, name)`` pair; only ``name`` is used.
    """
    _, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x
            for reduce in ('add', 'mean', 'min', 'max'):
                out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce=reduce)
                # Always pass dim_size: without it segment_coo infers the size
                # from row.max() + 1, which is too small when the matrix ends
                # in empty rows.
                out2 = segment_coo(x, row, dim_size=dim_size, reduce=reduce)
                out3 = segment_csr(x, rowptr, reduce=reduce)
                assert torch.allclose(out1, out2, atol=1e-4)
                assert torch.allclose(out1, out3, atol=1e-4)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise  # keep the original traceback
            torch.cuda.empty_cache()
def get_energy(batch, atomref):
    """Return the regression target energy for ``batch``.

    Raises ``MissingEnergyException`` when ``batch.y`` is None. Without
    ``atomref`` the result is a copy of ``batch.y``. With ``atomref``, the
    per-graph sum of ``atomref[batch.z]`` (grouped by ``batch.batch``) is
    subtracted from ``batch.y``. Both operands are squeezed before the
    subtraction.
    """
    target = batch.y
    if target is None:
        raise MissingEnergyException()
    if atomref is None:
        return target.clone()

    # Remove the reference energies of the atoms from each graph's target.
    per_atom_ref = atomref[batch.z]
    per_graph_ref = scatter(per_atom_ref, batch.batch, dim=0)
    residual = target.squeeze() - per_graph_ref.squeeze()
    return residual.clone()
def extract_node_feature(data, reduce='add'):
    """Set ``data.x`` by pooling edge attributes onto nodes, and return ``data``.

    ``data.edge_attr`` is reduced over ``data.edge_index[0]`` (the source row
    in the usual PyG convention). The result has ``data.num_nodes`` rows.
    ``data`` is modified in place.

    Args:
        data: graph object exposing ``edge_attr``, ``edge_index`` and ``num_nodes``.
        reduce: one of ``'add'``, ``'sum'``, ``'mean'``, ``'min'`` or ``'max'``.

    Raises:
        ValueError: for any other ``reduce``. ValueError subclasses
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    if reduce not in ('add', 'sum', 'mean', 'min', 'max'):
        raise ValueError(f'Unknown Aggregation Type: {reduce!r}')
    data.x = scatter(data.edge_attr, data.edge_index[0], dim=0,
                     dim_size=data.num_nodes, reduce=reduce)
    return data
def forward_pyg(self, x, adj):
    """Attention-weighted propagation over the sparse adjacency ``adj``.

    Each edge ``(row, col)`` is scored by ``self.beta`` times the cosine
    similarity of ``x[row]`` and ``x[col]``. Scores are softmax-normalized per
    ``row`` group, and the weighted ``x[row]`` rows are summed into ``col``.

    Returns:
        Tensor with exactly ``x.size(0)`` rows. Nodes that receive no edge get
        a zero row.
    """
    row, col = adj.coalesce().indices()
    x_row = x[row]
    x_col = x[col]
    sim = self.beta * cosine_similarity(x_row, x_col)
    P = softmax(sim, row)
    src = x_row * P.view(-1, 1)
    # Without dim_size the output would stop at col.max() + 1 rows and drop
    # trailing nodes that have no incoming edge.
    out = scatter(src, col, dim=0, dim_size=x.size(0), reduce="add")
    return out
def aggregate(self, inputs, index, dim_size):  # pragma: no cover
    r"""Combine neighbor messages as :math:`\square_{j \in \mathcal{N}(i)}`.

    The messages in :obj:`inputs` are reduced into the nodes given by
    :obj:`index` along :obj:`self.node_dim` with a scatter operation. The
    reduction ("add", "mean", "max", ...) is the :obj:`aggr` argument chosen
    in :meth:`__init__`. :obj:`dim_size` fixes the number of output nodes.
    """
    reduction = self.aggr
    return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                   reduce=reduction)
def _mix_into_atoms(self, x: torch.Tensor, x_paths: PathActivationCollection, atom_to_path_map) -> torch.Tensor:
    """Add path activations back onto the atoms, one path length at a time.

    For every key in ``self.path_keys``:
      - the path activations are gathered by the ``col`` indices of
        ``atom_to_path_map[key]``;
      - they are mean-pooled onto atoms by ``row``, with one row per atom in
        ``x``;
      - the result goes through that key's linear layer and a ReLU, and is
        added residually to the running atom features.
    """
    result = x
    for path_len in self.path_keys:
        atom_idx, path_idx = atom_to_path_map[path_len]
        projection = self.path_to_atom_linear[path_len]
        gathered = x_paths[path_len][path_idx]
        pooled = scatter(gathered, atom_idx, dim=0, dim_size=x.size(0),
                         reduce='mean')
        result = result + F.relu(projection(pooled))
    return result
def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor:
    """Sum ``inputs`` into the groups named by ``index`` along ``self.node_dim``.

    NOTE(review): the original local here was called ``out_mean``, but the
    reduction has always been ``"sum"``. Confirm that summing is intended.
    """
    return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                   reduce="sum")
def forward(self, x, rbf, i, num_nodes=None):
    """Gate ``x`` by projected RBF features, sum per index ``i``, then run the MLP.

    The summed features are lifted by ``self.xupproj`` and passed through
    ``self.lins`` (each followed by ``self.act``). The result of
    ``self.final_lin`` is returned.
    """
    gated = self.lin_rbf(rbf) * x
    # Same role as tf.math.unsorted_segment_sum(x, idnb_i, n_atoms).
    pooled = scatter(gated, i, dim=0, dim_size=num_nodes)
    h = self.xupproj(pooled)
    for layer in self.lins:
        h = self.act(layer(h))
    return self.final_lin(h)
def forward(self, data):
    """Voxel-pool ``data`` in place and return it.

    Points are bucketed into voxels of size ``self.pool_rad``. The grid spans
    the point bounding box padded by half a voxel on each side. Per voxel:
      - features are reduced with ``self.aggr``;
      - positions are averaged;
      - one batch entry is kept (via ``perm``).
    Edges are dropped because they no longer match the pooled points.
    """
    pos = data.pos
    half_voxel = self.pool_rad * 0.5
    lower = pos.min(dim=0)[0] - half_voxel
    upper = pos.max(dim=0)[0] + half_voxel
    cluster = nn_geometric.voxel_grid(pos, data.batch, self.pool_rad,
                                      start=lower, end=upper)
    cluster, perm = consecutive_cluster(cluster)

    data.x = scatter(data.x, cluster, dim=0, reduce=self.aggr)
    data.pos = scatter(pos, cluster, dim=0, reduce='mean')
    data.batch = data.batch[perm]
    data.edge_attr = None
    data.edge_index = None
    return data
def aggregate(self, inputs, index, ptr=None, dim_size=None):
    """Concatenate mean- and max-aggregated messages along the last dimension.

    When a CSR pointer ``ptr`` is given, ``segment_csr`` is used. Otherwise
    ``scatter`` over ``index`` with ``dim_size`` output nodes is used. Both
    paths reduce along ``self.node_dim``.
    """
    if ptr is not None:
        # segment_csr reduces along dimension ``ptr.dim() - 1``, so left-pad
        # ptr with singleton dims until it lines up with node_dim. node_dim
        # may be negative (e.g. -2). Normalize it first, otherwise
        # range(node_dim) is empty and no padding happens.
        node_dim = self.node_dim
        if node_dim < 0:
            node_dim += inputs.dim()
        for _ in range(node_dim):
            ptr = ptr.unsqueeze(0)
        aggr_mean = segment_csr(inputs, ptr, reduce='mean')
        aggr_max = segment_csr(inputs, ptr, reduce='max')
    else:
        aggr_mean = scatter(inputs, index, dim=self.node_dim,
                            dim_size=dim_size, reduce='mean')
        aggr_max = scatter(inputs, index, dim=self.node_dim,
                           dim_size=dim_size, reduce='max')
    return torch.cat([aggr_mean, aggr_max], dim=-1)
def forward(self, x: Tensor, batch: Optional[Tensor] = None, dim_size: Optional[int] = None) -> Tensor:
    """Subtract the per-graph feature mean from each row of ``x``.

    Rows are grouped by ``batch``. Without ``batch``, all rows form one group.
    ``dim_size`` fixes the number of groups.
    """
    if batch is not None:
        group_mean = scatter(x, batch, dim=0, dim_size=dim_size, reduce='mean')
        return x - group_mean[batch]
    return x - x.mean(dim=0, keepdim=True)
def forward(self, z, pos, batch=None):
    """Predict one value per graph from atom types ``z`` and positions ``pos``.

    Pipeline:
      1. Embed atoms.
      2. Build a radius graph (``self.cutoff``, at most
         ``self.max_num_neighbors`` neighbors).
      3. Expand the edge distances.
      4. Apply the interaction blocks residually.
      5. Apply two dense layers.
      6. Read out per graph with ``self.readout``.

    In dipole mode the atom outputs are weighted by their offset from the
    mass-weighted center of their graph, and the norm of the pooled vector is
    returned. Otherwise optional mean/std de-normalization and ``atomref``
    offsets are applied before readout. ``self.scale`` multiplies the result
    when set.
    """
    assert z.dim() == 1 and z.dtype == torch.long
    if batch is None:
        batch = torch.zeros_like(z)

    h = self.embedding(z)
    edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                              max_num_neighbors=self.max_num_neighbors)
    src, dst = edge_index
    edge_weight = (pos[src] - pos[dst]).norm(dim=-1)
    edge_attr = self.distance_expansion(edge_weight)

    for block in self.interactions:
        h = h + block(h, edge_index, edge_weight, edge_attr)

    h = self.lin2(self.act(self.lin1(h)))

    if self.dipole:
        # Mass-weighted center of each graph.
        mass = self.atomic_mass[z].view(-1, 1)
        center = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
        h = h * (pos - center[batch])
    else:
        if self.mean is not None and self.std is not None:
            h = h * self.std + self.mean
        if self.atomref is not None:
            h = h + self.atomref(z)

    out = scatter(h, batch, dim=0, reduce=self.readout)
    if self.dipole:
        out = torch.norm(out, dim=-1, keepdim=True)
    if self.scale is not None:
        out = self.scale * out
    return out
def softmax(
    src: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    dim: int = 0,
) -> Tensor:
    r"""Compute a softmax separately within each group of :attr:`src`.

    Entries of :attr:`src` are grouped along :attr:`dim`. Groups are defined
    either by :attr:`index` or, for sorted input, by the CSR pointer
    :attr:`ptr`. Each group is normalized independently. For numerical
    stability the group maximum is subtracted before exponentiation, and
    ``1e-16`` is added to the denominator.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor, optional): Group assignment of each element.
            (default: :obj:`None`)
        ptr (LongTensor, optional): If given, groups are read from this CSR
            pointer over sorted inputs; this takes priority over
            :attr:`index`. (default: :obj:`None`)
        num_nodes (int, optional): The number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dim (int, optional): The dimension to normalize over.
            (default: :obj:`0`)

    :rtype: :class:`Tensor`
    """
    if ptr is None and index is None:
        raise NotImplementedError

    if ptr is not None:
        if dim < 0:
            dim = dim + src.dim()
        ptr = ptr.view([1] * dim + [-1])
        group_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        shifted_exp = (src - group_max).exp()
        group_sum = gather_csr(segment_csr(shifted_exp, ptr, reduce='sum'), ptr)
    else:
        N = maybe_num_nodes(index, num_nodes)
        group_max = scatter(src, index, dim, dim_size=N, reduce='max')
        group_max = group_max.index_select(dim, index)
        shifted_exp = (src - group_max).exp()
        group_sum = scatter(shifted_exp, index, dim, dim_size=N, reduce='sum')
        group_sum = group_sum.index_select(dim, index)

    return shifted_exp / (group_sum + 1e-16)
def forward(self, x_h, x_g, edge_index, edge_attr, u, batch_g):
    """Update the ``x_g`` nodes from messages sent by ``x_h`` nodes.

    For each edge ``(src, tgt)``, ``node_mlp_1`` is applied to
    ``[x_h[src], edge_attr]``. The messages are summed onto ``tgt``, giving
    ``x_g.size(0)`` rows. The sums are concatenated with ``x_g`` and the
    per-node global features ``u[batch_g]``, then passed through
    ``node_mlp_2``.
    """
    src, tgt = edge_index
    msg = self.node_mlp_1(torch.cat([x_h[src], edge_attr], dim=1))
    # An unused `torch.ones(...).cuda()` used to be allocated here; it crashed
    # on CPU-only machines and is removed.
    # NOTE(review): the original comment called this "mu", but the reduction
    # is a sum, not a mean.
    aggregated = scatter(msg, tgt, dim=0, dim_size=x_g.size(0), reduce='sum')
    out = torch.cat([x_g, aggregated, u[batch_g]], dim=1)
    return self.node_mlp_2(out)
def test_scatter(reduce):
    # The experimental 'scatter_reduce' path must match torch_scatter.
    torch.manual_seed(12345)
    src = torch.randn(8, 100, 32)
    index = torch.randint(0, 10, (100, ), dtype=torch.long)

    with torch_geometric.experimental_mode('scatter_reduce'):
        actual = scatter(src, index, dim=1, reduce=reduce)
    expected = torch_scatter.scatter(src, index, dim=1, reduce=reduce)

    assert torch.allclose(actual, expected, atol=1e-6)
def forward(self, z, pos, batch=None):
    """Predict one scalar per graph from atom types ``z`` and positions ``pos``.

    Pipeline:
      1. Build a radius graph (at most 1000 neighbors per node).
      2. Compute per-edge features: spherical harmonics (scaled by
         ``1 / sqrt(self.num_neighbors)``), a radial embedding of the edge
         length, and a cosine cutoff envelope.
      3. Run the convolution layers with optional gated shortcuts.
      4. Reduce every output channel of ``self.Rs_out`` to a scalar per atom.
      5. Optionally de-normalize and add ``atomref``.
      6. Read out per graph with ``self.readout``.
    """
    assert z.dim() == 1 and z.dtype == torch.long
    batch = torch.zeros_like(z) if batch is None else batch
    h = self.embedding(z)
    edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=1000)
    row, col = edge_index
    edge_vec = pos[row] - pos[col]
    edge_sh = o3.spherical_harmonics(self.Rs_sh, edge_vec, 'component') / self.num_neighbors**0.5
    edge_len = edge_vec.norm(dim=1)
    edge_weight = self.radial(edge_len)
    # Smooth cutoff envelope (1 + cos(pi * d / cutoff)) / 2.
    edge_c = (pi * edge_len / self.cutoff).cos().add(1).div(2)
    for conv, act, shortcut in self.layers[:-1]:
        with torch.autograd.profiler.record_function("Layer"):
            if shortcut:
                s = shortcut(h)
            h = conv(h, edge_index, edge_weight, edge_c, edge_sh)  # convolution
            h = act(h)  # gate non linearity
            if shortcut:
                # Where output_mask m == 1, both the conv output and the
                # shortcut are weighted by sqrt(1/2). Where m == 0, the conv
                # output keeps weight 1 (presumably the shortcut is zero in
                # those channels -- confirm).
                m = shortcut.output_mask
                h = 0.5**0.5 * s + (1 + (0.5**0.5 - 1) * m) * h
    with torch.autograd.profiler.record_function("Layer"):
        h = self.layers[-1](h, edge_index, edge_weight, edge_c, edge_sh)
    # Every output channel must be a single scalar (mul == 1, l == 0).
    # Even-parity channels are added as-is; odd-parity channels are squared
    # (times 0.5) so the sum is parity-even.
    s = 0
    for i, (mul, l, p) in enumerate(self.Rs_out):
        assert mul == 1 and l == 0
        if p == 1:
            s += h[:, i]
        if p == -1:
            s += h[:, i].pow(2).mul(0.5)  # odd^2 = even
    h = s.view(-1, 1)
    if self.mean is not None and self.std is not None:
        h = h * self.std + self.mean
    if self.atomref is not None:
        h = h + self.atomref(z)
    out = scatter(h, batch, dim=0, reduce=self.readout)
    if self.scale is not None:
        out = self.scale * out
    return out
def forward(self, data):
    """Predict per-graph energy and, optionally, forces for ``data``.

    With ``self.otf_graph``, the periodic radius graph is built on the fly and
    written back onto ``data`` (``edge_index``, ``cell_offsets``,
    ``neighbors``). With ``self.use_pbc``, periodic distances drive a SchNet
    pass implemented here. Otherwise the parent class's non-periodic
    ``forward`` is used.

    Returns:
        ``energy``, or ``(energy, forces)`` when ``self.regress_forces`` is
        set. Forces are the negative gradient of the energy with respect to
        the positions, computed with ``create_graph=True`` so they can be
        trained on.
    """
    z = data.atomic_numbers.long()
    pos = data.pos
    if self.regress_forces:
        # requires_grad_ is in-place, so this also flags data.pos itself.
        pos = pos.requires_grad_(True)
    batch = data.batch

    if self.otf_graph:
        edge_index, cell_offsets, neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device)
        data.edge_index = edge_index
        data.cell_offsets = cell_offsets
        data.neighbors = neighbors

    # TODO return distance computation in radius_graph_pbc to remove need
    # for get_pbc_distances call
    if self.use_pbc:
        assert z.dim() == 1 and z.dtype == torch.long

        out = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
        )

        edge_index = out["edge_index"]
        edge_weight = out["distances"]
        edge_attr = self.distance_expansion(edge_weight)

        h = self.embedding(z)
        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        # Treat the input as a single graph when no batch vector is supplied.
        batch = torch.zeros_like(z) if batch is None else batch
        energy = scatter(h, batch, dim=0, reduce=self.readout)
    else:
        energy = super(SchNetWrap, self).forward(z, pos, batch)

    if self.regress_forces:
        forces = -1 * (torch.autograd.grad(
            energy,
            pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0])
        return energy, forces
    else:
        return energy
def aggregate(self, inputs, index, dim_size=None):
    """Mean-pool messages ``inputs`` into the groups named by ``index``.

    The reduction runs along ``self.node_dim``. ``dim_size`` optionally fixes
    the number of output groups.
    """
    return torch_scatter.scatter(
        inputs,
        index=index,
        dim=self.node_dim,
        reduce='mean',
        dim_size=dim_size,
    )
def forward(self, x_h, x_g, edge_index, edge_attr, u, batch_h):
    """Update ``x_h`` nodes from moment statistics of their edge messages.

    For each edge ``(src, tgt)``, ``node_mlp_1`` is applied to
    ``[x_g[tgt], edge_attr]``. The messages are pooled onto the ``x_h`` nodes
    by ``src`` as count, mean, standard deviation, skewness and kurtosis. These
    are concatenated with ``x_h`` and the per-node global features
    ``u[batch_h]``, then passed through ``node_mlp_2``.

    The ``1e-6`` inside the square root keeps the standard deviation at least
    ``1e-3``, so the skewness and kurtosis divisions stay finite.
    """
    src, tgt = edge_index
    msg = self.node_mlp_1(torch.cat([x_g[tgt], edge_attr], dim=1))
    num_nodes = x_h.size(0)

    # Create the counting tensor on the messages' device. The old `.cuda()`
    # crashed on CPU-only runs and picked the wrong GPU for non-default devices.
    ones = torch.ones(msg.size(0), 1, dtype=torch.float, device=msg.device)
    n = scatter(ones, src, dim=0, dim_size=num_nodes, reduce='sum')  # count
    a = scatter(msg, src, dim=0, dim_size=num_nodes, reduce='mean')  # mean
    # Variance as E[x^2] - E[x]^2, clipped at 0 to absorb round-off.
    var = F.relu(scatter(msg**2, src, dim=0, dim_size=num_nodes, reduce='mean') - a**2)
    b = torch.sqrt(1e-6 + var)  # standard deviation
    centered = msg - a[src]
    c = scatter(centered**3, src, dim=0, dim_size=num_nodes, reduce='mean') / b**3  # skewness
    d = scatter(centered**4, src, dim=0, dim_size=num_nodes, reduce='mean') / b**4  # kurtosis

    out = torch.cat([x_h, n, a, b, c, d, u[batch_h]], dim=1)
    return self.node_mlp_2(out)
def forward(self, data) -> torch.Tensor:
    """Predict one scalar per graph from ``data['z']``, ``data['pos']`` and ``data['batch']``.

    Pipeline:
      1. Build a radius graph (at most 1000 neighbors per node).
      2. Encode edges with spherical harmonics up to ``self.sh_lmax`` and a
         smooth radial basis of the edge length.
      3. Run the message-passing network ``self.mp``.
      4. Combine its two output channels into one scalar per atom, divide by
         ``sqrt(self.num_nodes)``, and optionally add ``self.atomref``.
      5. Sum per graph.
    """
    node_atom = data['z']
    node_pos = data['pos']
    batch = data['batch']

    # The graph
    edge_src, edge_dst = radius_graph(node_pos, r=self.max_radius, batch=batch, max_num_neighbors=1000)

    # Edge attributes
    edge_vec = node_pos[edge_src] - node_pos[edge_dst]
    edge_sh = o3.spherical_harmonics(l=range(self.sh_lmax + 1), x=edge_vec, normalize=True, normalization='component')

    # Edge length embedding
    edge_length = edge_vec.norm(dim=1)
    edge_length_embedding = soft_one_hot_linspace(
        edge_length,
        0.0,
        self.max_radius,
        self.num_basis,
        basis='smooth_finite',
        cutoff=True,
    ).mul(self.num_basis**0.5)

    # Constant input feature (1.0) for every node.
    node_input = node_pos.new_ones(node_pos.shape[0], 1)

    # Lookup table: values 1, 6, 7, 8, 9 map to classes 0..4 (presumably
    # H, C, N, O, F atomic numbers -- confirm). Any other value maps to -1,
    # which one_hot rejects, and values >= 10 are out of range for the table.
    node_attr = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[node_atom]
    node_attr = torch.nn.functional.one_hot(node_attr, 5).mul(5**0.5)

    node_outputs = self.mp(node_features=node_input, node_attr=node_attr, edge_src=edge_src, edge_dst=edge_dst, edge_attr=edge_sh, edge_scalars=edge_length_embedding)

    # Channel 0 is used directly; channel 1 is squared (times 0.5).
    node_outputs = node_outputs[:, 0] + node_outputs[:, 1].pow(2).mul(0.5)
    node_outputs = node_outputs.view(-1, 1)

    node_outputs = node_outputs.div(self.num_nodes**0.5)

    if self.atomref is not None:
        node_outputs = node_outputs + self.atomref[node_atom]
    # for target=7, MAE of 75eV

    # Per-graph sum of the atom contributions.
    outputs = scatter(node_outputs, batch, dim=0)

    return outputs
def determine_step(dr):
    """Cap the step so the largest per-atom step in each group is at most ``self.maxstep``.

    Atoms are grouped by ``self.atoms.batch`` (presumably one structure per
    group). Each atom's step is scaled by
    ``min(L, maxstep) / (L + 1e-7)``, where ``L`` is the longest step norm in
    its group.

    NOTE: ``dr`` is rescaled IN PLACE. The returned tensor is a separate,
    damped copy (``dr * self.damping``).
    """
    step_norm = torch.norm(dr, dim=1)
    group = self.atoms.batch
    per_atom_longest = scatter(step_norm, group, reduce="max")[group]
    cap = per_atom_longest.new_tensor(self.maxstep)
    factor = (per_atom_longest + 1e-7).reciprocal() * torch.min(per_atom_longest, cap)
    dr *= factor.unsqueeze(1)
    return dr * self.damping
def get_node(x, segment, mode='mean'):
    """Pool per-pixel features into one feature row per segment label.

    Args:
        x: 3-D array/tensor whose last axis holds the ``c`` features; it is
            flattened to ``(-1, c)``.
        segment: 2-D array/tensor of integer labels, flattened to align with
            the rows of ``x``.
        mode: scatter reduction used for pooling (default ``'mean'``).

    Returns:
        float32 tensor of shape ``(segment.max() + 1, c)``.
    """
    assert x.ndim == 3 and segment.ndim == 2
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    if isinstance(segment, np.ndarray):
        segment = torch.from_numpy(segment)
    # scatter needs int64 indices. Previously only numpy inputs were cast, so
    # an int32/uint8 torch segment map made scatter fail.
    segment = segment.to(torch.long)
    c = x.shape[2]
    x = x.reshape((-1, c))
    index = segment.flatten()
    nodes = scatter(x, index, dim=0, reduce=mode)
    return nodes.to(torch.float32)
def aggregate(self, inputs, index, ptr=None, dim_size=None):
    """Aggregate messages with the generalized aggregator chosen by ``self.aggr``.

    Modes:

    * ``'add'``, ``'mean'``, ``'max'`` or ``None``: delegate to the parent
      class.
    * ``'softmax'``, ``'softmax_sg'``, ``'softmax_sum'``: sum of ``inputs``
      weighted by a per-group softmax of ``inputs * self.t``. When
      ``self.learn_t`` is false the weights are computed under
      ``torch.no_grad()``.
    * ``'power'``, ``'power_sum'``: power mean with exponent ``self.p`` over
      inputs clamped to ``[1e-7, 10]``. The result is clamped to the same
      range before the ``1/p`` root.

    The ``*_sum`` variants also multiply by ``degree ** sigmoid(self.y)``. The
    sigmoid is stored on ``self.sigmoid_y``.

    Raises:
        NotImplementedError: for any other ``self.aggr``.
    """
    if self.aggr in ['add', 'mean', 'max', None]:
        return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size)

    elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']:
        if self.learn_t:
            out = scatter_softmax(inputs*self.t, index, dim=self.node_dim)
        else:
            with torch.no_grad():
                out = scatter_softmax(inputs*self.t, index, dim=self.node_dim)

        out = scatter(inputs*out, index, dim=self.node_dim,
                      dim_size=dim_size, reduce='sum')

        if self.aggr == 'softmax_sum':
            self.sigmoid_y = torch.sigmoid(self.y)
            degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
            out = torch.pow(degrees, self.sigmoid_y) * out

        return out

    elif self.aggr in ['power', 'power_sum']:
        min_value, max_value = 1e-7, 1e1
        # Clamp out of place. The old `torch.clamp_(inputs, ...)` overwrote
        # the caller's message tensor.
        inputs = torch.clamp(inputs, min_value, max_value)
        out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
                      dim_size=dim_size, reduce='mean')
        out = torch.clamp(out, min_value, max_value)
        out = torch.pow(out, 1/self.p)

        if self.aggr == 'power_sum':
            self.sigmoid_y = torch.sigmoid(self.y)
            degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
            out = torch.pow(degrees, self.sigmoid_y) * out

        return out

    else:
        raise NotImplementedError('To be implemented')
def _compute_score(self, all_clusters, backbone_features, semantic_logits):
    """Return one confidence score per cluster in ``all_clusters``.

    With ``self._activate_scorer``:
      - each cluster's backbone features and coordinates are packed into a
        CPU ``Data`` batch (one batch id per cluster) and run through
        ``self.Scorer``;
      - unless the scorer is an encoder, its outputs are max-pooled per
        cluster;
      - ``self.ScorerHead`` turns the pooled features into scores.

    Otherwise, without gradients, the semantic logits are mean-pooled per
    cluster, and the largest class value is used as the score.
    """
    if not self._activate_scorer:
        # Use semantic certainty as cluster confidence.
        with torch.no_grad():
            logit_parts, id_parts = [], []
            for cluster_id, cluster in enumerate(all_clusters):
                logit_parts.append(semantic_logits[cluster, :])
                id_parts.append(cluster_id * torch.ones(cluster.shape[0]))
            cluster_ids = torch.cat(id_parts).long().to(self.device)
            mean_logits = scatter(torch.cat(logit_parts), cluster_ids,
                                  dim=0, reduce="mean")
            return torch.max(mean_logits, 1)[0]

    feature_parts, coord_parts, id_parts = [], [], []
    for cluster_id, cluster in enumerate(all_clusters):
        feature_parts.append(backbone_features[cluster])
        coord_parts.append(self.input.coords[cluster])
        id_parts.append(cluster_id * torch.ones(cluster.shape[0]))
    packed = Data(
        x=torch.cat(feature_parts).cpu(),
        coords=torch.cat(coord_parts).cpu(),
        batch=torch.cat(id_parts).cpu(),
    )
    scorer_out = self.Scorer(packed)
    if self._scorer_is_encoder:
        cluster_feats = scorer_out.x
    else:
        cluster_ids = scorer_out.batch.long().to(self.device)
        cluster_feats = scatter(scorer_out.x, cluster_ids, dim=0, reduce="max")
    return self.ScorerHead(cluster_feats).squeeze(-1)
def forward(self, ps, JsorShapes, ws, poses, batch, check_rotation=True, is_Rotation=False):
    """Blend per-joint 4x4 transforms with per-point weights and apply them to ``ps``.

    Args:
        ps: points, made homogeneous internally (a 1 is appended).
        JsorShapes: per-sample joints, or shape parameters (10 values per
            sample) that are turned into joints with ``self.smpl.skeleton``.
        ws: per-point weights over the 24 joints. They are split per sample
            using the number of points of each ``batch`` id.
        poses: per-sample pose with 24 joints. Three forms are accepted:
            - 3 values per joint: axis-angle, converted with
              ``batch_rodrigues``;
            - 9 values per joint: 3x3 matrices, Gram-Schmidt orthonormalized
              unless ``is_Rotation``;
            - 16 values per joint: ready-made 4x4 transforms.
        batch: sample id of every point.
        check_rotation: unused.
        is_Rotation: treat 9-value poses as exact rotation matrices.

    Returns:
        Tuple ``(transformed_points[:, :3], T, Rs, Js_transformed)``. ``Rs``
        and ``Js_transformed`` are None for 4x4 poses.

    Raises:
        ValueError: if ``poses`` matches none of the sizes above. Previously
            this surfaced as an UnboundLocalError on ``A``.
    """
    batch_num = poses.shape[0]
    assert (batch_num == JsorShapes.shape[0])
    if JsorShapes.shape.numel() == batch_num * 10:  # is shapes
        Js = self.smpl.skeleton(JsorShapes)
    else:
        Js = JsorShapes

    if poses.numel() == batch_num * 24 * 3:
        Rs = batch_rodrigues(poses.view(-1, 3)).view(-1, 24, 3, 3)
        Js_transformed, A = batch_global_rigid_transformation(
            Rs, Js, self.smpl.parents, rotate_base=False)
    elif poses.numel() == batch_num * 24 * 9:  # input poses are general matrices
        if not is_Rotation:
            ms = poses.reshape(-1, 3, 3)
            # Gram-Schmidt orthonormalization of the first two columns;
            # the third column is their cross product.
            b1 = F.normalize(ms[:, :, 0], dim=1)
            dot_prod = torch.sum(b1 * ms[:, :, 1], dim=1, keepdim=True)
            b2 = F.normalize(ms[:, :, 1] - dot_prod * b1, dim=-1)
            b3 = torch.cross(b1, b2, dim=1)
            Rs = torch.stack([b1, b2, b3], dim=-1).reshape(batch_num, 24, 3, 3)
        else:
            Rs = poses.reshape(batch_num, 24, 3, 3)
        Js_transformed, A = batch_global_rigid_transformation(
            Rs, Js, self.smpl.parents, rotate_base=False)
    elif poses.numel() == batch_num * 24 * 16:
        A = poses.reshape(batch_num, 24, 4, 4)
        Js_transformed = None
        Rs = None
    else:
        raise ValueError(
            f'Unsupported poses size {poses.numel()} for batch of {batch_num}: '
            f'expected {batch_num} x 24 x (3, 9 or 16) values')

    # Number of points per sample, used to split the weights per sample.
    splitl = torch_scatter.scatter(batch.new_ones(batch.numel(), 1), batch, dim=0).cpu().numpy().reshape(-1).astype(
        np.int32).tolist()
    ws = ws.split(splitl, 0)
    T = torch.cat(
        [weight.matmul(a.reshape(24, 16)) for weight, a in zip(ws, A)], dim=0)
    T = T.reshape(-1, 4, 4)
    ps = torch.cat((ps, ps.new_ones(ps.shape[0], 1)), dim=-1).unsqueeze(-1)
    ps = torch.matmul(T, ps).squeeze(-1)
    return ps[:, 0:3], T, Rs, Js_transformed
def rand_prop(self, x, edge_index, edge_weight):
    """Random propagation: DropNode, then the average of ``self.order + 1`` hops.

    The edge weights are normalized by ``self.normalize_adj``. Each hop sums
    weighted ``h[edge_index[1]]`` rows into ``edge_index[0]`` and is detached.
    Hops are accumulated IN PLACE into the tensor returned by
    ``self.dropNode``. The detached average is returned.
    """
    norm_weight = self.normalize_adj(edge_index, edge_weight, x.shape[0])
    dst, src = edge_index[0], edge_index[1]
    h = self.dropNode(x)
    # Deliberately aliases h: the running sum lives in the DropNode output.
    acc = h
    weight_col = norm_weight[:, None]
    for _ in range(self.order):
        h = scatter(h[src] * weight_col, dst[:, None], dim=0,
                    dim_size=h.shape[0], reduce='sum').detach_()
        acc.add_(h)
    return acc.div_(self.order + 1.0).detach_()