class schnet(torch.nn.Module):
    """SchNet-style model assembled from edge, node and graph update blocks.

    ``forward`` builds a radius graph over atom positions and embeds the
    pairwise distances. It then alternates ``num_layers`` times between an
    edge refresh (``update_e``) and a node refresh (``update_v``). Finally an
    ``update_u`` block reduces node features using ``batch``.

    Args:
        energy_and_force (bool): When truthy, ``forward`` calls
            ``pos.requires_grad_()`` so gradients w.r.t. positions can be
            taken from the output.
        cutoff (float): Radius given to ``radius_graph``, the distance
            embedding and the edge blocks. (default: ``10.0``)
        num_layers (int): Number of edge/node update rounds. (default: ``6``)
        hidden_channels (int): Node feature size. (default: ``128``)
        num_filters (int): Filter size used by the update blocks.
            (default: ``128``)
        num_gaussians (int): Size of the distance expansion.
            (default: ``50``)
    """

    def __init__(self, energy_and_force, cutoff=10.0, num_layers=6,
                 hidden_channels=128, num_filters=128, num_gaussians=50):
        super(schnet, self).__init__()
        self.energy_and_force = energy_and_force
        self.cutoff = cutoff
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_gaussians = num_gaussians

        # 100-row lookup table indexed by the ``z`` values in ``forward``.
        self.init_v = Embedding(100, hidden_channels)
        self.dist_emb = emb(0.0, cutoff, num_gaussians)

        node_blocks = [update_v(hidden_channels, num_filters)
                       for _ in range(num_layers)]
        self.update_vs = torch.nn.ModuleList(node_blocks)
        edge_blocks = [update_e(hidden_channels, num_filters, num_gaussians,
                                cutoff)
                       for _ in range(num_layers)]
        self.update_es = torch.nn.ModuleList(edge_blocks)
        self.update_u = update_u(hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the embedding and every update block."""
        self.init_v.reset_parameters()
        for edge_block in self.update_es:
            edge_block.reset_parameters()
        for node_block in self.update_vs:
            node_block.reset_parameters()
        self.update_u.reset_parameters()

    def forward(self, batch_data):
        """Run the model on ``batch_data``.

        Args:
            batch_data: Object exposing ``z``, ``pos`` and ``batch``.

        Returns:
            Whatever the final ``update_u`` block returns for the batch.
        """
        z = batch_data.z
        pos = batch_data.pos
        batch = batch_data.batch
        if self.energy_and_force:
            pos.requires_grad_()

        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        src, dst = edge_index
        dist = (pos[src] - pos[dst]).norm(dim=-1)
        dist_emb = self.dist_emb(dist)

        v = self.init_v(z)
        for edge_block, node_block in zip(self.update_es, self.update_vs):
            e = edge_block(v, dist, dist_emb, edge_index)
            v = node_block(v, e, edge_index)
        return self.update_u(v, batch)
class SpanRepresentation(Module):
    """Attention-pooled span encoder over a token vocabulary.

    Processing steps:

    1. Token ids are embedded, with dropout.
    2. They are optionally contextualized by an ``LSTMContextualizer``, but
       only when ``config.n_lstm_layers > 0``.
    3. They are pooled into one vector per sequence: masked attention
       weights are applied to a per-token projection.

    Args:
        config: Configuration object. Reads ``d_embed``, ``n_lstm_layers``,
            ``dropout``, ``d_lstm_hidden`` and, if present,
            ``normalize_pretrained`` (default ``False``).
        d_output (int): Size of the pooled output vector.
        vocab: Vocabulary. ``len(vocab)`` sets the embedding table size, and
            ``vocab.vectors`` (when not ``None``) is copied into it.
    """

    def __init__(self, config, d_output, vocab):
        super(SpanRepresentation, self).__init__()
        self.config = config
        self.vocab = vocab
        n_input = len(vocab)
        self.embedding = Embedding(n_input, config.d_embed)
        self.normalize_pretrained = getattr(config, 'normalize_pretrained',
                                            False)
        # Without LSTM layers the contextualizer is the identity function.
        self.contextualizer = LSTMContextualizer(
            config) if config.n_lstm_layers > 0 else lambda x: x
        self.dropout = Dropout(p=config.dropout)
        # NOTE(review): both heads expect ``2 * d_lstm_hidden`` input features
        # (presumably a bidirectional LSTM output). With ``n_lstm_layers == 0``
        # the heads see raw embeddings, which then requires
        # ``d_embed == 2 * d_lstm_hidden``.
        self.head_attention = Sequential(
            self.dropout, Linear(2 * config.d_lstm_hidden, 1))
        self.head_transform = Sequential(
            self.dropout, Linear(2 * config.d_lstm_hidden, d_output))
        self.init()

    def init(self):
        """Initialize weights and load pretrained token vectors if available.

        Every parameter with more than one dimension gets Xavier-normal
        initialization. The embedding table is then handled in one of two
        ways:

        * If ``vocab.vectors`` is set, it is overwritten with those vectors.
          They are first L2-normalized along the last dimension when
          ``normalize_pretrained`` is set.
        * Otherwise it is reset with ``reset_parameters()``.
        """
        for p in self.parameters():
            if p.dim() > 1:
                xavier_normal(p)
        if self.vocab.vectors is not None:
            if self.normalize_pretrained:
                pretrained = normalize(self.vocab.vectors, dim=-1)
            else:
                pretrained = self.vocab.vectors
            self.embedding.weight.data.copy_(pretrained)
            print(
                'Copied pretrained vectors into relation span representation')
        else:
            self.embedding.reset_parameters()

    def forward(self, inputs):
        """Pool each sequence in ``inputs`` into a single vector.

        Args:
            inputs: A ``(text, mask)`` pair. ``text`` holds token ids and
                ``mask`` marks valid positions (it is cast with ``.float()``).
                Presumably ``text`` is ``(batch, seq_len)``; TODO confirm
                against callers.

        Returns:
            Tensor: The transformed tokens, weighted by the masked attention
            weights and summed over dimension 1.
        """
        text, mask = inputs
        text = self.dropout(self.embedding(text))
        text = self.contextualizer(text)
        weights = masked_softmax(
            self.head_attention(text).squeeze(-1), mask.float())
        representation = (weights.unsqueeze(2) *
                          self.head_transform(text)).sum(dim=1)
        return representation
class EmbGCN(torch.nn.Module):
    """GCN over learned node embeddings with a residual first layer.

    Input node features are not used. Each node instead gets a learned
    embedding, looked up by ``data.node_index``. That embedding feeds two
    paths:

    * ``conv1``;
    * ``first_lin`` followed by the stack of ``convs``.

    The two paths are summed before the output classifier, and the result is
    returned as log-probabilities.

    Args:
        num_layers (int): ``num_layers - 1`` GCN layers follow ``first_lin``.
        hidden (int): Hidden size, raised to at least ``2 * num_class``.
        emb_dim (int): Node embedding size.
        num_class (int): Number of output classes.
        num_nodes (int): Rows of the node embedding table. It defaults to
            ``None``, but a value is presumably required to build
            ``Embedding(num_nodes, emb_dim)``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_layers=2, hidden=32, emb_dim=64, num_class=2,
                 num_nodes=None, **kwargs):
        super(EmbGCN, self).__init__()
        hidden = max(hidden, num_class * 2)
        self.conv1 = GCNConv(emb_dim, hidden)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))
        self.lin2 = Linear(hidden, num_class)
        self.emb = Embedding(num_nodes, emb_dim)
        self.first_lin = Linear(emb_dim, hidden)

    def reset_parameters(self):
        """Re-initialize all learnable sub-modules."""
        self.first_lin.reset_parameters()
        self.emb.reset_parameters()
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        """Return log-softmax class scores for the nodes in ``data``.

        Args:
            data: Graph object exposing ``edge_index``, ``edge_weight`` and
                ``node_index``. ``data.x`` is not used.
        """
        edge_index, edge_weight = data.edge_index, data.edge_weight
        x = self.emb(data.node_index)
        # Residual branch: a single GCN layer applied to the raw embeddings.
        x1 = F.elu(self.conv1(x, edge_index, edge_weight=edge_weight))
        x = F.elu(self.first_lin(x))
        x = F.dropout(x, p=0.5, training=self.training)
        for conv in self.convs:
            x = F.elu(conv(x, edge_index, edge_weight=edge_weight))
        x = x1 + x
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class SchNet(torch.nn.Module):
    r"""The continuous-filter convolutional neural network SchNet from the
    `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling
    Quantum Interactions" <https://arxiv.org/abs/1706.08566>`_ paper that uses
    the interaction blocks of the form

    .. math::
        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot
        h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))),

    where :math:`h_{\mathbf{\Theta}}` denotes an MLP and
    :math:`\mathbf{e}_{j,i}` denotes the interatomic distances between atoms.

    .. note::

        For an example of using a pretrained SchNet variant, see
        `examples/qm9_pretrained_schnet.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        qm9_pretrained_schnet.py>`_.

    Args:
        hidden_channels (int, optional): Hidden embedding size.
            (default: :obj:`128`)
        num_filters (int, optional): The number of filters to use.
            (default: :obj:`128`)
        num_interactions (int, optional): The number of interaction blocks.
            (default: :obj:`6`)
        num_gaussians (int, optional): The number of gaussians :math:`\mu`.
            (default: :obj:`50`)
        cutoff (float, optional): Cutoff distance for interatomic interactions.
            (default: :obj:`10.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance.
            (default: :obj:`32`)
        readout (str, optional): Whether to apply :obj:`"add"` or
            :obj:`"mean"` global aggregation. Forced to :obj:`"add"` when
            :obj:`dipole` is set. (default: :obj:`"add"`)
        dipole (bool, optional): If set to :obj:`True`, will use the magnitude
            of the dipole moment to make the final prediction, *e.g.*, for
            target 0 of :class:`torch_geometric.datasets.QM9`.
            (default: :obj:`False`)
        mean (float, optional): The mean of the property to predict.
            (default: :obj:`None`)
        std (float, optional): The standard deviation of the property to
            predict. (default: :obj:`None`)
        atomref (torch.Tensor, optional): The reference of single-atom
            properties. It is copied into a :obj:`(100, 1)` embedding weight,
            so it is expected to have shape :obj:`(100, 1)` (one value per
            atomic number below 100). (default: :obj:`None`)
    """

    url = 'http://www.quantum-machine.org/datasets/trained_schnet_models.zip'

    def __init__(self, hidden_channels: int = 128, num_filters: int = 128,
                 num_interactions: int = 6, num_gaussians: int = 50,
                 cutoff: float = 10.0, max_num_neighbors: int = 32,
                 readout: str = 'add', dipole: bool = False,
                 mean: Optional[float] = None, std: Optional[float] = None,
                 atomref: Optional[torch.Tensor] = None):
        super(SchNet, self).__init__()

        import ase

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.readout = readout
        self.dipole = dipole
        # Dipole prediction always sums the per-atom contributions.
        self.readout = 'add' if self.dipole else self.readout
        self.mean = mean
        self.std = std
        self.scale = None

        atomic_mass = torch.from_numpy(ase.data.atomic_masses)
        self.register_buffer('atomic_mass', atomic_mass)

        self.embedding = Embedding(100, hidden_channels)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(hidden_channels, num_gaussians,
                                     num_filters, cutoff)
            self.interactions.append(block)

        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(hidden_channels // 2, 1)

        # Keep the user-supplied reference so reset_parameters can restore it.
        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all weights.

        Linear layers get Xavier-uniform weights and zero biases. The atomref
        table, if present, is restored from ``initial_atomref``.
        """
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
        if self.atomref is not None:
            self.atomref.weight.data.copy_(self.initial_atomref)

    @staticmethod
    def from_qm9_pretrained(root: str, dataset: Dataset, target: int):
        """Load a pretrained SchNetPack QM9 model for ``target``.

        Steps:

        1. If ``root`` has no ``trained_schnet_models`` folder, download and
           extract :attr:`url` into it.
        2. Copy the stored SchNetPack weights into a freshly built
           :class:`SchNet`.
        3. Restrict the stored train/val/test split to molecules present in
           ``dataset``.

        Args:
            root (str): Directory holding, or receiving, the pretrained models.
            dataset (Dataset): The QM9 dataset. It must expose ``data.idx``
                and ``atomref(target)``.
            target (int): Index of the QM9 target, in ``[0, 11]``.

        Returns:
            ``(net, (train_dataset, val_dataset, test_dataset))``
        """
        import ase
        import schnetpack as spk  # noqa

        units = [1] * 12
        units[0] = ase.units.Debye
        units[1] = ase.units.Bohr**3
        units[5] = ase.units.Bohr**2
        # ``units`` holds one entry per supported target and is indexed with
        # ``target`` at the end. Validate against it before downloading.
        assert 0 <= target < len(units)

        root = osp.expanduser(osp.normpath(root))
        makedirs(root)
        folder = 'trained_schnet_models'
        if not osp.exists(osp.join(root, folder)):
            path = download_url(SchNet.url, root)
            extract_zip(path, root)
            os.unlink(path)

        name = f'qm9_{qm9_target_dict[target]}'
        path = osp.join(root, 'trained_schnet_models', name, 'split.npz')

        split = np.load(path)
        train_idx = split['train_idx']
        val_idx = split['val_idx']
        test_idx = split['test_idx']

        # Filter the splits to only contain characterized molecules.
        idx = dataset.data.idx
        assoc = idx.new_empty(idx.max().item() + 1)
        assoc[idx] = torch.arange(idx.size(0))

        train_idx = assoc[train_idx[np.isin(train_idx, idx)]]
        val_idx = assoc[val_idx[np.isin(val_idx, idx)]]
        test_idx = assoc[test_idx[np.isin(test_idx, idx)]]

        path = osp.join(root, 'trained_schnet_models', name, 'best_model')

        # NOTE(security): torch.load unpickles the downloaded archive, which
        # is fetched over plain HTTP. Only use a trusted copy of the file.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            state = torch.load(path, map_location='cpu')

        net = SchNet(hidden_channels=128, num_filters=128, num_interactions=6,
                     num_gaussians=50, cutoff=10.0,
                     atomref=dataset.atomref(target))

        net.embedding.weight = state.representation.embedding.weight

        for int1, int2 in zip(state.representation.interactions,
                              net.interactions):
            int2.mlp[0].weight = int1.filter_network[0].weight
            int2.mlp[0].bias = int1.filter_network[0].bias
            int2.mlp[2].weight = int1.filter_network[1].weight
            int2.mlp[2].bias = int1.filter_network[1].bias
            int2.lin.weight = int1.dense.weight
            int2.lin.bias = int1.dense.bias

            int2.conv.lin1.weight = int1.cfconv.in2f.weight
            int2.conv.lin2.weight = int1.cfconv.f2out.weight
            int2.conv.lin2.bias = int1.cfconv.f2out.bias

        output_module = state.output_modules[0]
        out_net = output_module.out_net[1].out_net
        net.lin1.weight = out_net[0].weight
        net.lin1.bias = out_net[0].bias
        net.lin2.weight = out_net[1].weight
        net.lin2.bias = out_net[1].bias

        mean = output_module.atom_pool.average
        net.readout = 'mean' if mean is True else 'add'

        dipole = output_module.__class__.__name__ == 'DipoleMoment'
        net.dipole = dipole

        net.mean = output_module.standardize.mean.item()
        net.std = output_module.standardize.stddev.item()

        # NOTE(review): if ``dataset.atomref(target)`` returned None,
        # ``net.atomref`` is None here and a checkpoint atomref would raise.
        if output_module.atomref is not None:
            net.atomref.weight = output_module.atomref.weight
        else:
            net.atomref = None

        net.scale = 1. / units[target]

        return net, (dataset[train_idx], dataset[val_idx], dataset[test_idx])

    def forward(self, z, pos, batch=None):
        """Compute the per-graph prediction.

        Args:
            z (torch.Tensor): 1-D ``long`` tensor of atomic numbers (asserted).
            pos (torch.Tensor): Atom positions, indexed per atom.
            batch (torch.Tensor, optional): Graph assignment per atom. When
                ``None``, all atoms are treated as a single graph.

        Returns:
            torch.Tensor: Per-atom outputs reduced per graph with
            :attr:`readout`. In dipole mode this is the norm of the result.
            It is multiplied by :attr:`scale` when that is set.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)

        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        if self.dipole:
            # Get center of mass.
            mass = self.atomic_mass[z].view(-1, 1)
            c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
            h = h * (pos - c[batch])

        if not self.dipole and self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if not self.dipole and self.atomref is not None:
            h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)

        if self.scale is not None:
            out = self.scale * out

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'hidden_channels={self.hidden_channels}, '
                f'num_filters={self.num_filters}, '
                f'num_interactions={self.num_interactions}, '
                f'num_gaussians={self.num_gaussians}, '
                f'cutoff={self.cutoff})')
class MetaPath2Vec(torch.nn.Module):
    r"""The MetaPath2Vec model from the `"metapath2vec: Scalable Representation
    Learning for Heterogeneous Networks"
    <https://ericdongyx.github.io/papers/
    KDD17-dong-chawla-swami-metapath2vec.pdf>`_ paper where random walks based
    on a given :obj:`metapath` are sampled in a heterogeneous graph, and node
    embeddings are learned via negative sampling optimization.

    .. note::

        For an example of using MetaPath2Vec, see
        `examples/metapath2vec.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        metapath2vec.py>`_.

    Args:
        edge_index_dict (dict): Dictionary holding edge indices for each
            :obj:`(source_node_type, relation_type, target_node_type)` present
            in the heterogeneous graph.
        embedding_dim (int): The size of each embedding vector.
        metapath (list): The metapath described as a list of
            :obj:`(source_node_type, relation_type, target_node_type)` tuples.
            Its last target node type must equal its first source node type.
        walk_length (int): The walk length.
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes. It may be
            at most :obj:`walk_length + 1`.
        walks_per_node (int, optional): The number of walks to sample for each
            node. (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples to
            use for each positive sample. (default: :obj:`1`)
        num_nodes_dict (dict, optional): Dictionary holding the number of nodes
            for each node type. (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the
            weight matrix will be sparse. (default: :obj:`False`)
    """

    def __init__(self, edge_index_dict, embedding_dim, metapath, walk_length,
                 context_size, walks_per_node=1, num_negative_samples=1,
                 num_nodes_dict=None, sparse=False):
        super(MetaPath2Vec, self).__init__()

        if num_nodes_dict is None:
            num_nodes_dict = {}
            for keys, edge_index in edge_index_dict.items():
                key = keys[0]
                N = int(edge_index[0].max() + 1)
                num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))

                key = keys[-1]
                N = int(edge_index[1].max() + 1)
                num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))

        adj_dict = {}
        for keys, edge_index in edge_index_dict.items():
            sizes = (num_nodes_dict[keys[0]], num_nodes_dict[keys[-1]])
            row, col = edge_index
            adj = SparseTensor(row=row, col=col, sparse_sizes=sizes)
            adj = adj.to('cpu')
            adj_dict[keys] = adj

        # Sampling cycles through ``metapath`` (step ``i`` uses
        # ``metapath[i % len(metapath)]``), so it must end on its start type.
        assert metapath[0][0] == metapath[-1][-1]
        # A sampled walk holds ``walk_length + 1`` nodes (start + one per
        # step), which bounds the usable context window.
        assert walk_length + 1 >= context_size

        self.adj_dict = adj_dict
        self.embedding_dim = embedding_dim
        self.metapath = metapath
        self.walk_length = walk_length
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.num_negative_samples = num_negative_samples
        self.num_nodes_dict = num_nodes_dict

        types = sorted({x[0] for x in metapath} | {x[-1] for x in metapath})

        # All node types share one embedding table. Each type owns the
        # contiguous row range [start, end).
        count = 0
        self.start, self.end = {}, {}
        for key in types:
            self.start[key] = count
            count += num_nodes_dict[key]
            self.end[key] = count

        # Row offset to add to the local node id at each position of a walk.
        offset = [self.start[metapath[0][0]]]
        offset += [self.start[keys[-1]] for keys in metapath
                   ] * int((walk_length / len(metapath)) + 1)
        offset = offset[:walk_length + 1]
        assert len(offset) == walk_length + 1
        self.offset = torch.tensor(offset)

        self.embedding = Embedding(count, embedding_dim, sparse=sparse)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()

    def forward(self, node_type, batch=None):
        """Returns the embeddings for the nodes in :obj:`batch` of type
        :obj:`node_type` (all nodes of that type if :obj:`batch` is None)."""
        emb = self.embedding.weight[self.start[node_type]:self.end[node_type]]
        return emb if batch is None else emb[batch]

    def loader(self, **kwargs):
        """Return a DataLoader over start-type node ids whose batches are
        ``(positive_walks, negative_walks)`` pairs."""
        return DataLoader(range(self.num_nodes_dict[self.metapath[0][0]]),
                          collate_fn=self.sample, **kwargs)

    def _windows(self, rw):
        """Stack every length-``context_size`` window of each walk in ``rw``."""
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        return torch.cat([rw[:, j:j + self.context_size]
                          for j in range(num_walks_per_rw)], dim=0)

    def pos_sample(self, batch):
        """Sample metapath-guided walks from ``batch`` and cut them into
        context windows (ids are global embedding rows)."""
        batch = batch.repeat(self.walks_per_node)

        rws = [batch]
        for i in range(self.walk_length):
            keys = self.metapath[i % len(self.metapath)]
            adj = self.adj_dict[keys]
            # ``reshape(-1)`` keeps the result 1-D even when only one walk is
            # sampled. ``squeeze()`` would yield a 0-d tensor there and break
            # ``torch.stack`` below.
            # NOTE(review): nodes without outgoing edges are not
            # special-cased; confirm how ``SparseTensor.sample`` handles them.
            batch = adj.sample(num_neighbors=1, subset=batch).reshape(-1)
            rws.append(batch)

        rw = torch.stack(rws, dim=-1)
        rw.add_(self.offset.view(1, -1))
        return self._windows(rw)

    def neg_sample(self, batch):
        """Build random walks with uniformly drawn nodes of the metapath's
        types and cut them into context windows."""
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)

        rws = [batch]
        for i in range(self.walk_length):
            keys = self.metapath[i % len(self.metapath)]
            batch = torch.randint(0, self.num_nodes_dict[keys[-1]],
                                  (batch.size(0), ), dtype=torch.long)
            rws.append(batch)

        rw = torch.stack(rws, dim=-1)
        rw.add_(self.offset.view(1, -1))
        return self._windows(rw)

    def sample(self, batch):
        """Collate function: return ``(pos_sample(batch), neg_sample(batch))``."""
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)
        return self.pos_sample(batch), self.neg_sample(batch)

    def _walk_scores(self, rw):
        """Dot products between each walk's first node and its remaining
        nodes, flattened to 1-D."""
        start, rest = rw[:, 0], rw[:, 1:].contiguous()
        h_start = self.embedding(start).view(rw.size(0), 1,
                                             self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(rw.size(0), -1,
                                                    self.embedding_dim)
        return (h_start * h_rest).sum(dim=-1).view(-1)

    def loss(self, pos_rw, neg_rw):
        r"""Computes the loss given positive and negative random walks."""
        # Positive loss.
        out = self._walk_scores(pos_rw)
        pos_loss = -torch.log(torch.sigmoid(out) + EPS).mean()

        # Negative loss.
        out = self._walk_scores(neg_rw)
        neg_loss = -torch.log(1 - torch.sigmoid(out) + EPS).mean()

        return pos_loss + neg_loss

    def test(self, train_z, train_y, test_z, test_y, solver='lbfgs',
             multi_class='auto', *args, **kwargs):
        r"""Evaluates latent space quality via a logistic regression downstream
        task."""
        from sklearn.linear_model import LogisticRegression
        clf = LogisticRegression(solver=solver, multi_class=multi_class, *args,
                                 **kwargs).fit(train_z.detach().cpu().numpy(),
                                               train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
                         test_y.detach().cpu().numpy())

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.embedding.weight.size(0)}, '
                f'{self.embedding.weight.size(1)})')
class Node2Vec(torch.nn.Module):
    r"""The Node2Vec model from the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper where random walks of
    length :obj:`walk_length` are sampled in a given graph, and node embeddings
    are learned via negative sampling optimization.

    .. note::

        For an example of using Node2Vec, see `examples/node2vec.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        node2vec.py>`_.

    Args:
        edge_index (LongTensor): The edge indices.
        embedding_dim (int): The size of each embedding vector.
        walk_length (int): The walk length (number of nodes per walk,
            including the start node).
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for each
            node. (default: :obj:`1`)
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples to
            use for each positive sample. (default: :obj:`1`)
        num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the
            weight matrix will be sparse. (default: :obj:`False`)
    """

    def __init__(self, edge_index, embedding_dim, walk_length, context_size,
                 walks_per_node=1, p=1, q=1, num_negative_samples=1,
                 num_nodes=None, sparse=False):
        super().__init__()

        if random_walk is None:
            raise ImportError('`Node2Vec` requires `torch-cluster`.')

        num_total = maybe_num_nodes(edge_index, num_nodes)
        src, dst = edge_index
        graph = SparseTensor(row=src, col=dst,
                             sparse_sizes=(num_total, num_total))
        self.adj = graph.to('cpu')

        assert walk_length >= context_size

        self.embedding_dim = embedding_dim
        # Stored as the number of *steps*; a walk then holds ``walk_length``
        # nodes once the start node is included.
        self.walk_length = walk_length - 1
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples

        self.embedding = Embedding(num_total, embedding_dim, sparse=sparse)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()

    def forward(self, batch=None):
        """Returns the embeddings for the nodes in :obj:`batch` (all nodes
        when :obj:`batch` is None)."""
        table = self.embedding.weight
        if batch is None:
            return table
        return table.index_select(0, batch)

    def loader(self, **kwargs):
        """Return a DataLoader over node ids whose batches are
        ``(positive_walks, negative_walks)`` pairs."""
        node_ids = range(self.adj.sparse_size(0))
        return DataLoader(node_ids, collate_fn=self.sample, **kwargs)

    def _split_windows(self, walks):
        """Concatenate every length-``context_size`` slice of ``walks``."""
        width = self.context_size
        n_windows = 1 + self.walk_length + 1 - width
        pieces = [walks[:, k:k + width] for k in range(n_windows)]
        return torch.cat(pieces, dim=0)

    def pos_sample(self, batch):
        """Biased random walks from ``batch`` cut into context windows."""
        starts = batch.repeat(self.walks_per_node)
        rowptr, col, _ = self.adj.csr()
        walks = random_walk(rowptr, col, starts, self.walk_length, self.p,
                            self.q)
        # Some torch-cluster versions return a tuple; the walks come first.
        if not isinstance(walks, torch.Tensor):
            walks = walks[0]
        return self._split_windows(walks)

    def neg_sample(self, batch):
        """Uniformly random node sequences from ``batch`` cut into windows."""
        starts = batch.repeat(self.walks_per_node * self.num_negative_samples)
        tails = torch.randint(self.adj.sparse_size(0),
                              (starts.size(0), self.walk_length))
        walks = torch.cat([starts.view(-1, 1), tails], dim=-1)
        return self._split_windows(walks)

    def sample(self, batch):
        """Collate function: return ``(pos_sample(batch), neg_sample(batch))``."""
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)
        return self.pos_sample(batch), self.neg_sample(batch)

    def _pair_scores(self, walks):
        """Flattened dot products of each walk's first node with the rest."""
        n_rows = walks.size(0)
        head = walks[:, 0]
        tail = walks[:, 1:].contiguous()
        h_head = self.embedding(head).view(n_rows, 1, self.embedding_dim)
        h_tail = self.embedding(tail.view(-1)).view(n_rows, -1,
                                                    self.embedding_dim)
        return (h_head * h_tail).sum(dim=-1).view(-1)

    def loss(self, pos_rw, neg_rw):
        r"""Computes the loss given positive and negative random walks."""
        pos_logits = self._pair_scores(pos_rw)
        pos_loss = -torch.log(torch.sigmoid(pos_logits) + EPS).mean()

        neg_logits = self._pair_scores(neg_rw)
        neg_loss = -torch.log(1 - torch.sigmoid(neg_logits) + EPS).mean()

        return pos_loss + neg_loss

    def test(self, train_z, train_y, test_z, test_y, solver='lbfgs',
             multi_class='auto', *args, **kwargs):
        r"""Evaluates latent space quality via a logistic regression downstream
        task."""
        from sklearn.linear_model import LogisticRegression
        clf = LogisticRegression(solver=solver, multi_class=multi_class,
                                 *args, **kwargs)
        clf.fit(train_z.detach().cpu().numpy(), train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
                         test_y.detach().cpu().numpy())

    def __repr__(self) -> str:
        rows, dims = self.embedding.weight.size(0), self.embedding.weight.size(1)
        return f'{self.__class__.__name__}({rows}, {dims})'
class MetaPath2Vec(torch.nn.Module):
    r"""The MetaPath2Vec model from the `"metapath2vec: Scalable Representation
    Learning for Heterogeneous Networks"
    <https://ericdongyx.github.io/papers/
    KDD17-dong-chawla-swami-metapath2vec.pdf>`_ paper where random walks based
    on a given :obj:`metapath` are sampled in a heterogeneous graph, and node
    embeddings are learned via negative sampling optimization.

    .. note::

        For an example of using MetaPath2Vec, see
        `examples/hetero/metapath2vec.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        hetero/metapath2vec.py>`_.

    Args:
        edge_index_dict (Dict[Tuple[str, str, str], Tensor]): Dictionary
            holding edge indices for each
            :obj:`(src_node_type, rel_type, dst_node_type)` present in the
            heterogeneous graph.
        embedding_dim (int): The size of each embedding vector.
        metapath (List[Tuple[str, str, str]]): The metapath described as a list
            of :obj:`(src_node_type, rel_type, dst_node_type)` tuples.
        walk_length (int): The walk length.
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for each
            node. (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples to
            use for each positive sample. (default: :obj:`1`)
        num_nodes_dict (Dict[str, int], optional): Dictionary holding the
            number of nodes for each node type. (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the
            weight matrix will be sparse. (default: :obj:`False`)
    """

    def __init__(
        self,
        edge_index_dict: Dict[EdgeType, Tensor],
        embedding_dim: int,
        metapath: List[EdgeType],
        walk_length: int,
        context_size: int,
        walks_per_node: int = 1,
        num_negative_samples: int = 1,
        num_nodes_dict: Optional[Dict[NodeType, int]] = None,
        sparse: bool = False,
    ):
        super().__init__()

        if num_nodes_dict is None:
            # Infer each type's size from the largest index seen, checking the
            # source side before the destination side of every edge type.
            num_nodes_dict = {}
            for edge_type, edge_index in edge_index_dict.items():
                endpoints = ((edge_type[0], edge_index[0]),
                             (edge_type[-1], edge_index[1]))
                for node_type, ids in endpoints:
                    n = int(ids.max() + 1)
                    num_nodes_dict[node_type] = max(
                        n, num_nodes_dict.get(node_type, n))

        adj_dict = {}
        for edge_type, edge_index in edge_index_dict.items():
            src, dst = edge_index
            shape = (num_nodes_dict[edge_type[0]],
                     num_nodes_dict[edge_type[-1]])
            adj_dict[edge_type] = SparseTensor(
                row=src, col=dst, sparse_sizes=shape).to('cpu')

        assert walk_length + 1 >= context_size
        is_cycle = metapath[0][0] == metapath[-1][-1]
        if walk_length > len(metapath) and not is_cycle:
            raise AttributeError(
                "The 'walk_length' is longer than the given 'metapath', but "
                "the 'metapath' does not denote a cycle")

        self.adj_dict = adj_dict
        self.embedding_dim = embedding_dim
        self.metapath = metapath
        self.walk_length = walk_length
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.num_negative_samples = num_negative_samples
        self.num_nodes_dict = num_nodes_dict

        node_types = sorted({et[0] for et in metapath}
                            | {et[-1] for et in metapath})

        # Lay out all node types back to back in a single embedding table.
        self.start, self.end = {}, {}
        cursor = 0
        for node_type in node_types:
            self.start[node_type] = cursor
            cursor += num_nodes_dict[node_type]
            self.end[node_type] = cursor

        # Table offset applied at each position of a walk of
        # ``walk_length + 1`` nodes.
        repeats = int((walk_length / len(metapath)) + 1)
        offset = [self.start[metapath[0][0]]]
        offset.extend([self.start[et[-1]] for et in metapath] * repeats)
        offset = offset[:walk_length + 1]
        assert len(offset) == walk_length + 1
        self.offset = torch.tensor(offset)

        # + 1 denotes a dummy node used to link to for isolated nodes.
        self.embedding = Embedding(cursor + 1, embedding_dim, sparse=sparse)
        self.dummy_idx = cursor

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()

    def forward(self, node_type: str, batch: OptTensor = None) -> Tensor:
        r"""Returns the embeddings for the nodes in :obj:`batch` of type
        :obj:`node_type`."""
        lo, hi = self.start[node_type], self.end[node_type]
        table = self.embedding.weight[lo:hi]
        if batch is None:
            return table
        return table.index_select(0, batch)

    def loader(self, **kwargs):
        r"""Returns the data loader that creates both positive and negative
        random walks on the heterogeneous graph.

        Args:
            **kwargs (optional): Arguments of
                :class:`torch.utils.data.DataLoader`, such as
                :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or
                :obj:`num_workers`.
        """
        first_type = self.metapath[0][0]
        return DataLoader(range(self.num_nodes_dict[first_type]),
                          collate_fn=self._sample, **kwargs)

    def _cut_windows(self, walks: Tensor) -> Tensor:
        """Concatenate every length-``context_size`` slice of ``walks``."""
        width = self.context_size
        n_windows = 1 + self.walk_length + 1 - width
        return torch.cat([walks[:, k:k + width] for k in range(n_windows)],
                         dim=0)

    def _pos_sample(self, batch: Tensor) -> Tensor:
        current = batch.repeat(self.walks_per_node)

        steps = [current]
        for step in range(self.walk_length):
            edge_type = self.metapath[step % len(self.metapath)]
            current = sample(self.adj_dict[edge_type], current,
                             num_neighbors=1,
                             dummy_idx=self.dummy_idx).view(-1)
            steps.append(current)

        walks = torch.stack(steps, dim=-1)
        walks.add_(self.offset.view(1, -1))
        # Dummy ids pick up an offset above; map them back to the dummy row.
        walks.clamp_(max=self.dummy_idx)
        return self._cut_windows(walks)

    def _neg_sample(self, batch: Tensor) -> Tensor:
        current = batch.repeat(self.walks_per_node * self.num_negative_samples)

        steps = [current]
        for step in range(self.walk_length):
            edge_type = self.metapath[step % len(self.metapath)]
            current = torch.randint(0, self.num_nodes_dict[edge_type[-1]],
                                    (current.size(0), ), dtype=torch.long)
            steps.append(current)

        walks = torch.stack(steps, dim=-1)
        walks.add_(self.offset.view(1, -1))
        return self._cut_windows(walks)

    def _sample(self, batch: List[int]) -> Tuple[Tensor, Tensor]:
        if not isinstance(batch, Tensor):
            batch = torch.tensor(batch, dtype=torch.long)
        return self._pos_sample(batch), self._neg_sample(batch)

    def _pair_scores(self, walks: Tensor) -> Tensor:
        """Flattened dot products of each walk's first node with the rest."""
        n_rows = walks.size(0)
        head = walks[:, 0]
        tail = walks[:, 1:].contiguous()
        h_head = self.embedding(head).view(n_rows, 1, self.embedding_dim)
        h_tail = self.embedding(tail.view(-1)).view(n_rows, -1,
                                                    self.embedding_dim)
        return (h_head * h_tail).sum(dim=-1).view(-1)

    def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
        r"""Computes the loss given positive and negative random walks."""
        pos_logits = self._pair_scores(pos_rw)
        pos_loss = -torch.log(torch.sigmoid(pos_logits) + EPS).mean()

        neg_logits = self._pair_scores(neg_rw)
        neg_loss = -torch.log(1 - torch.sigmoid(neg_logits) + EPS).mean()

        return pos_loss + neg_loss

    def test(self, train_z: Tensor, train_y: Tensor, test_z: Tensor,
             test_y: Tensor, solver: str = "lbfgs", multi_class: str = "auto",
             *args, **kwargs) -> float:
        r"""Evaluates latent space quality via a logistic regression downstream
        task."""
        from sklearn.linear_model import LogisticRegression
        clf = LogisticRegression(solver=solver, multi_class=multi_class,
                                 *args, **kwargs)
        clf.fit(train_z.detach().cpu().numpy(), train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
                         test_y.detach().cpu().numpy())

    def __repr__(self) -> str:
        # The dummy row is an implementation detail and is not reported.
        n_real = self.embedding.weight.size(0) - 1
        return (f'{self.__class__.__name__}('
                f'{n_real}, '
                f'{self.embedding.weight.size(1)})')
class HEATConv(MessagePassing):
    r"""The heterogeneous edge-enhanced graph attentional operator from the
    `"Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent
    Trajectory Prediction" <https://arxiv.org/abs/2106.07161>`_ paper, which
    enhances :class:`~torch_geometric.nn.conv.GATConv` by:

    1. type-specific transformations of nodes of different types
    2. edge type and edge feature incorporation, in which edges are assumed to
       have different types but contain the same kind of attributes

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        num_node_types (int): The number of node types.
        num_edge_types (int): The number of edge types.
        edge_type_emb_dim (int): The embedding size of edge types.
        edge_dim (int): Edge feature dimensionality.
        edge_attr_emb_dim (int): The embedding size of edge features.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 num_node_types: int, num_edge_types: int,
                 edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int,
                 heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.0,
                 root_weight: bool = True, bias: bool = True, **kwargs):

        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.root_weight = root_weight

        # Sub-modules are created in a fixed order so parameter initialisation
        # and state_dict keys stay stable.
        self.hetero_lin = HeteroLinear(in_channels, out_channels,
                                       num_node_types, bias=bias)
        self.edge_type_emb = Embedding(num_edge_types, edge_type_emb_dim)
        self.edge_attr_emb = Linear(edge_dim, edge_attr_emb_dim, bias=False)

        # Attention input is [x_i || x_j || edge-type emb || edge-attr emb].
        att_in_dim = 2 * out_channels + edge_type_emb_dim + edge_attr_emb_dim
        self.att = Linear(att_in_dim, self.heads, bias=False)

        # Message input is [x_j || edge-attr emb].
        msg_in_dim = out_channels + edge_attr_emb_dim
        self.lin = Linear(msg_in_dim, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        for module in (self.hetero_lin, self.edge_type_emb,
                       self.edge_attr_emb, self.att, self.lin):
            module.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor,
                edge_type: Tensor, edge_attr: OptTensor = None,
                size: Size = None) -> Tensor:
        """Apply type-specific node transforms, then attention propagation.

        Returns a tensor of width ``heads * out_channels`` when ``concat`` is
        set, otherwise of width ``out_channels`` (heads averaged).
        """
        x = self.hetero_lin(x, node_type)

        raw_type_emb = self.edge_type_emb(edge_type)
        edge_type_emb = F.leaky_relu(raw_type_emb, self.negative_slope)

        # propagate_type: (x: Tensor, edge_type_emb: Tensor, edge_attr: OptTensor)  # noqa
        out = self.propagate(edge_index, x=x, edge_type_emb=edge_type_emb,
                             edge_attr=edge_attr, size=size)

        if not self.concat:
            # Average over the head dimension.
            out = out.mean(dim=1)
            if self.root_weight:
                out += x
            return out

        if self.root_weight:
            # Broadcast root features over every head before flattening.
            out += x.view(-1, 1, self.out_channels)
        return out.view(-1, self.heads * self.out_channels)

    def message(self, x_i: Tensor, x_j: Tensor, edge_type_emb: Tensor,
                edge_attr: Tensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # Parameter names are inspected by MessagePassing; do not rename them.
        attr_emb = F.leaky_relu(self.edge_attr_emb(edge_attr),
                                self.negative_slope)

        att_input = torch.cat([x_i, x_j, edge_type_emb, attr_emb], dim=-1)
        scores = F.leaky_relu(self.att(att_input), self.negative_slope)
        # Normalise scores over the incoming edges of each target node.
        scores = softmax(scores, index, ptr, size_i)
        scores = F.dropout(scores, p=self.dropout, training=self.training)

        values = self.lin(torch.cat([x_j, attr_emb], dim=-1)).unsqueeze(-2)
        return values * scores.unsqueeze(-1)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return (f'{name}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
class Net(torch.nn.Module):
    """Graph classifier: categorical embedding -> AGGINConv stack ->
    jumping knowledge -> Set2Set pooling -> 3-layer MLP head.

    Args:
        num_classes: Number of output logits.
        gnn_layers: Number of AGGINConv layers.
        embed_dim: Embedding size used for each of the first two (categorical)
            input columns.
        hidden_dim: Hidden size of every GNN layer.
        jk_layer: Jumping-knowledge mode. A digit string (or int) selects
            ``'lstm'`` mode and is passed as ``num_layers``; otherwise
            ``'cat'`` or ``'max'``.
        process_step: Number of Set2Set processing steps.
        dropout: Stored as ``self.dropout``.

    Raises:
        ValueError: If ``jk_layer`` is not a digit string, ``'cat'`` or
            ``'max'``.
    """

    def __init__(self, num_classes, gnn_layers, embed_dim, hidden_dim,
                 jk_layer, process_step, dropout):
        super(Net, self).__init__()
        # NOTE(review): stored but not applied anywhere in forward().
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        # Shared by both categorical columns; their values must lie in [0, 6).
        self.embedding = Embedding(6, embed_dim)

        for i in range(gnn_layers):
            # The first layer sees the two embedded columns (2 * embed_dim)
            # concatenated with the 2 remaining raw input columns.
            in_dim = 2 * embed_dim + 2 if i == 0 else hidden_dim
            self.convs.append(
                AGGINConv(Sequential(Linear(in_dim, hidden_dim), ReLU(),
                                     Linear(hidden_dim, hidden_dim), ReLU(),
                                     BN(hidden_dim)), train_eps=True))

        if str(jk_layer).isdigit():
            # PyG's JumpingKnowledge takes ``num_layers`` (not ``gnn_layers``).
            self.jk = JumpingKnowledge(mode='lstm', channels=hidden_dim,
                                       num_layers=int(jk_layer))
            jk_dim = hidden_dim
        elif jk_layer == 'cat':
            self.jk = JumpingKnowledge(mode=jk_layer)
            jk_dim = gnn_layers * hidden_dim
        elif jk_layer == 'max':
            self.jk = JumpingKnowledge(mode=jk_layer)
            jk_dim = hidden_dim
        else:
            raise ValueError(
                f"Unsupported jk_layer {jk_layer!r}; expected a digit "
                f"string, 'cat' or 'max'.")

        self.s2s = Set2Set(jk_dim, processing_steps=process_step)
        # Set2Set output is twice its input width.
        self.fc1 = Linear(2 * jk_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = Linear(hidden_dim // 2, num_classes)

    def reset_parameters(self):
        """Re-initialise all learnable sub-modules."""
        self.embedding.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jk.reset_parameters()
        self.s2s.reset_parameters()
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        self.fc3.reset_parameters()

    def forward(self, data):
        """Return per-graph logits for ``data`` (needs ``x``, ``edge_index``,
        ``batch``)."""
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Embedding the categorical values from Gene expression and Node type
        # (the first two columns of x), then flatten both embeddings.
        xc = x[:, :2].type(torch.long)
        ems = self.embedding(xc)
        ems = ems.view(-1, ems.shape[1] * ems.shape[2])
        x = torch.cat((ems, x[:, 2:]), dim=1)

        xs = []
        for conv in self.convs:
            x = conv(x, edge_index)
            xs.append(x)

        x = self.jk(xs)
        x = self.s2s(x, batch)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        return logits
class Net(torch.nn.Module):
    """Molecular model with atom-level and (optional) clique-level message
    passing.

    Atoms are updated with ``GINEConv`` layers using bond encodings. When
    ``inter_message_passing`` is enabled, clique features are also updated
    with ``GINConv`` over ``data.tree_edge_index``, and atom/clique features
    are exchanged after every layer through ``data.atom2clique_index``.

    Args:
        hidden_channels: Hidden size shared by all layers.
        out_channels: Size of the final prediction.
        num_layers: Number of message-passing layers.
        dropout: Dropout probability used after each layer and in the head.
        inter_message_passing: Enable the clique branch and the exchange.
    """

    def __init__(self, hidden_channels, out_channels, num_layers, dropout=0.0,
                 inter_message_passing=True):
        super(Net, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        self.inter_message_passing = inter_message_passing

        self.atom_encoder = AtomEncoder(hidden_channels)
        # Clique labels must lie in [0, 4).
        self.clique_encoder = Embedding(4, hidden_channels)

        self.bond_encoders = ModuleList()
        self.atom_convs = ModuleList()
        self.atom_batch_norms = ModuleList()

        for _ in range(num_layers):
            self.bond_encoders.append(BondEncoder(hidden_channels))
            nn = Sequential(
                Linear(hidden_channels, 2 * hidden_channels),
                BatchNorm1d(2 * hidden_channels),
                ReLU(),
                Linear(2 * hidden_channels, hidden_channels),
            )
            self.atom_convs.append(GINEConv(nn, train_eps=True))
            self.atom_batch_norms.append(BatchNorm1d(hidden_channels))

        self.clique_convs = ModuleList()
        self.clique_batch_norms = ModuleList()

        for _ in range(num_layers):
            nn = Sequential(
                Linear(hidden_channels, 2 * hidden_channels),
                BatchNorm1d(2 * hidden_channels),
                ReLU(),
                Linear(2 * hidden_channels, hidden_channels),
            )
            self.clique_convs.append(GINConv(nn, train_eps=True))
            self.clique_batch_norms.append(BatchNorm1d(hidden_channels))

        self.atom2clique_lins = ModuleList()
        self.clique2atom_lins = ModuleList()

        for _ in range(num_layers):
            self.atom2clique_lins.append(
                Linear(hidden_channels, hidden_channels))
            self.clique2atom_lins.append(
                Linear(hidden_channels, hidden_channels))

        self.atom_lin = Linear(hidden_channels, hidden_channels)
        self.clique_lin = Linear(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise all learnable sub-modules."""
        self.atom_encoder.reset_parameters()
        self.clique_encoder.reset_parameters()

        for emb, conv, batch_norm in zip(self.bond_encoders, self.atom_convs,
                                         self.atom_batch_norms):
            emb.reset_parameters()
            conv.reset_parameters()
            batch_norm.reset_parameters()

        for conv, batch_norm in zip(self.clique_convs,
                                    self.clique_batch_norms):
            conv.reset_parameters()
            batch_norm.reset_parameters()

        for lin1, lin2 in zip(self.atom2clique_lins, self.clique2atom_lins):
            lin1.reset_parameters()
            lin2.reset_parameters()

        self.atom_lin.reset_parameters()
        self.clique_lin.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, data):
        """Return one prediction row per graph in ``data.batch``."""
        # squeeze(-1) rather than squeeze(): a batch holding a single atom or
        # a single clique must keep its leading node dimension.
        x = self.atom_encoder(data.x.squeeze(-1))

        if self.inter_message_passing:
            x_clique = self.clique_encoder(data.x_clique.squeeze(-1))
            # row indexes atoms, col indexes cliques (loop-invariant).
            row, col = data.atom2clique_index

        for i in range(self.num_layers):
            edge_attr = self.bond_encoders[i](data.edge_attr)
            x = self.atom_convs[i](x, data.edge_index, edge_attr)
            x = self.atom_batch_norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)

            if self.inter_message_passing:
                # Atoms -> cliques (mean over member atoms).
                x_clique = x_clique + F.relu(self.atom2clique_lins[i](scatter(
                    x[row], col, dim=0, dim_size=x_clique.size(0),
                    reduce='mean')))

                x_clique = self.clique_convs[i](x_clique, data.tree_edge_index)
                x_clique = self.clique_batch_norms[i](x_clique)
                x_clique = F.relu(x_clique)
                x_clique = F.dropout(x_clique, self.dropout,
                                     training=self.training)

                # Cliques -> atoms (mean over containing cliques).
                x = x + F.relu(self.clique2atom_lins[i](scatter(
                    x_clique[col], row, dim=0, dim_size=x.size(0),
                    reduce='mean')))

        x = scatter(x, data.batch, dim=0, reduce='mean')
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.atom_lin(x)

        if self.inter_message_passing:
            # Graph id of every clique, derived from per-graph clique counts.
            tree_batch = torch.repeat_interleave(data.num_cliques)
            x_clique = scatter(x_clique, tree_batch, dim=0, dim_size=x.size(0),
                               reduce='mean')
            x_clique = F.dropout(x_clique, self.dropout,
                                 training=self.training)
            x_clique = self.clique_lin(x_clique)
            x = x + x_clique

        x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.lin(x)
        return x
class LightGCN(torch.nn.Module):
    r"""The LightGCN model from the `"LightGCN: Simplifying and Powering
    Graph Convolution Network for Recommendation"
    <https://arxiv.org/abs/2002.02126>`_ paper.

    :class:`~torch_geometric.nn.models.LightGCN` learns embeddings by linearly
    propagating them on the underlying graph, and uses the weighted sum of the
    embeddings learned at all layers as the final embedding

    .. math::
        \textbf{x}_i = \sum_{l=0}^{L} \alpha_l \textbf{x}^{(l)}_i,

    where each layer's embedding is computed as

    .. math::
        \mathbf{x}^{(l+1)}_i = \sum_{j \in \mathcal{N}(i)}
        \frac{1}{\sqrt{\deg(i)\deg(j)}}\mathbf{x}^{(l)}_j.

    Two prediction heads and training objectives are provided:
    **link prediction** (via
    :meth:`~torch_geometric.nn.models.LightGCN.link_pred_loss` and
    :meth:`~torch_geometric.nn.models.LightGCN.predict_link`) and
    **recommendation** (via
    :meth:`~torch_geometric.nn.models.LightGCN.recommendation_loss` and
    :meth:`~torch_geometric.nn.models.LightGCN.recommend`).

    .. note::

        Embeddings are propagated according to the graph connectivity
        specified by :obj:`edge_index` while rankings or link probabilities
        are computed according to the edges specified by
        :obj:`edge_label_index`.

    Args:
        num_nodes (int): The number of nodes in the graph.
        embedding_dim (int): The dimensionality of node embeddings.
        num_layers (int): The number of
            :class:`~torch_geometric.nn.conv.LGConv` layers.
        alpha (float or Tensor, optional): The scalar (a Python number or a
            0-dimensional tensor) or vector of length :obj:`num_layers + 1`
            specifying the re-weighting coefficients for aggregating the final
            embedding. If set to :obj:`None`, the uniform initialization of
            :obj:`1 / (num_layers + 1)` is used. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of the underlying
            :class:`~torch_geometric.nn.conv.LGConv` layers.

    Raises:
        ValueError: If :obj:`alpha` is a 1-D or higher tensor whose first
            dimension is not :obj:`num_layers + 1`.
    """
    def __init__(
        self,
        num_nodes: int,
        embedding_dim: int,
        num_layers: int,
        alpha: Optional[Union[float, Tensor]] = None,
        **kwargs,
    ):
        super().__init__()

        self.num_nodes = num_nodes
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        if alpha is None:
            alpha = 1. / (num_layers + 1)

        if isinstance(alpha, Tensor):
            if alpha.dim() == 0:
                # A 0-dim tensor is a scalar: use it for every layer.
                alpha = alpha.repeat(num_layers + 1)
            elif alpha.size(0) != num_layers + 1:
                raise ValueError(
                    f"'alpha' must have {num_layers + 1} entries "
                    f"(num_layers + 1), got {alpha.size(0)}")
        else:
            alpha = torch.tensor([alpha] * (num_layers + 1))
        # Buffer (not a parameter): follows .to(device), is not trained.
        self.register_buffer('alpha', alpha)

        self.embedding = Embedding(num_nodes, embedding_dim)
        self.convs = ModuleList([LGConv(**kwargs) for _ in range(num_layers)])

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the node embeddings and all propagation layers."""
        self.embedding.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()

    def get_embedding(self, edge_index: Adj) -> Tensor:
        """Return the alpha-weighted sum of the embeddings of all layers."""
        x = self.embedding.weight
        out = x * self.alpha[0]

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            out = out + x * self.alpha[i + 1]

        return out

    def forward(self, edge_index: Adj,
                edge_label_index: OptTensor = None) -> Tensor:
        r"""Computes rankings for pairs of nodes.

        Args:
            edge_index (Tensor or SparseTensor): Edge tensor specifying the
                connectivity of the graph.
            edge_label_index (Tensor, optional): Edge tensor specifying the
                node pairs for which to compute rankings or probabilities.
                If :obj:`edge_label_index` is set to :obj:`None`, all edges in
                :obj:`edge_index` will be used instead. (default: :obj:`None`)
        """
        if edge_label_index is None:
            edge_label_index = edge_index

        out = self.get_embedding(edge_index)

        out_src = out[edge_label_index[0]]
        out_dst = out[edge_label_index[1]]
        return (out_src * out_dst).sum(dim=-1)

    def predict_link(self, edge_index: Adj, edge_label_index: OptTensor = None,
                     prob: bool = False) -> Tensor:
        r"""Predict links between nodes specified in :obj:`edge_label_index`.

        Args:
            prob (bool): Whether probabilities should be returned. Otherwise
                the probabilities are rounded to 0/1. (default: :obj:`False`)
        """
        pred = self(edge_index, edge_label_index).sigmoid()
        return pred if prob else pred.round()

    def recommend(self, edge_index: Adj, src_index: OptTensor = None,
                  dst_index: OptTensor = None, k: int = 1) -> Tensor:
        r"""Get top-:math:`k` recommendations for nodes in :obj:`src_index`.

        Args:
            src_index (Tensor, optional): Node indices for which
                recommendations should be generated.
                If set to :obj:`None`, all nodes will be used.
                (default: :obj:`None`)
            dst_index (Tensor, optional): Node indices which represent the
                possible recommendation choices.
                If set to :obj:`None`, all nodes will be used.
                (default: :obj:`None`)
            k (int, optional): Number of recommendations. (default: :obj:`1`)
        """
        out_src = out_dst = self.get_embedding(edge_index)

        if src_index is not None:
            out_src = out_src[src_index]

        if dst_index is not None:
            out_dst = out_dst[dst_index]

        pred = out_src @ out_dst.t()
        top_index = pred.topk(k, dim=-1).indices

        if dst_index is not None:  # Map local top-indices to original indices.
            top_index = dst_index[top_index.view(-1)].view(*top_index.size())

        return top_index

    def link_pred_loss(self, pred: Tensor, edge_label: Tensor,
                       **kwargs) -> Tensor:
        r"""Computes the model loss for a link prediction objective via the
        :class:`torch.nn.BCEWithLogitsLoss`.

        Args:
            pred (Tensor): The predictions.
            edge_label (Tensor): The ground-truth edge labels.
            **kwargs (optional): Additional arguments of the underlying
                :class:`torch.nn.BCEWithLogitsLoss` loss function.
        """
        loss_fn = torch.nn.BCEWithLogitsLoss(**kwargs)
        return loss_fn(pred, edge_label.to(pred.dtype))

    def recommendation_loss(self, pos_edge_rank: Tensor, neg_edge_rank: Tensor,
                            lambda_reg: float = 1e-4, **kwargs) -> Tensor:
        r"""Computes the model loss for a ranking objective via the Bayesian
        Personalized Ranking (BPR) loss.

        .. note::

            The i-th entry in the :obj:`pos_edge_rank` vector and i-th entry
            in the :obj:`neg_edge_rank` vector must correspond to ranks of
            positive and negative edges of the same entity (*e.g.*, user).

        Args:
            pos_edge_rank (Tensor): Positive edge rankings.
            neg_edge_rank (Tensor): Negative edge rankings.
            lambda_reg (float, optional): The :math:`L_2` regularization
                strength of the Bayesian Personalized Ranking (BPR) loss.
                (default: 1e-4)
            **kwargs (optional): Additional arguments of the underlying
                :class:`torch_geometric.nn.models.lightgcn.BPRLoss` loss
                function.
        """
        loss_fn = BPRLoss(lambda_reg, **kwargs)
        # Regularisation is applied to the full (layer-0) embedding table.
        return loss_fn(pos_edge_rank, neg_edge_rank, self.embedding.weight)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.num_nodes}, '
                f'{self.embedding_dim}, num_layers={self.num_layers})')
class SchNet(torch.nn.Module):
    r"""
        The re-implementation for SchNet from the `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions" <https://arxiv.org/abs/1706.08566>`_ paper
        under the 3DGN framework from `"Spherical Message Passing for 3D Graph Networks" <https://arxiv.org/abs/2102.05013v2>`_ paper.

        Args:
            energy_and_force (bool, optional): If set to :obj:`True`, will predict energy and take the negative of the derivative of the energy with respect to the atomic positions as predicted forces. (default: :obj:`False`)
            num_layers (int, optional): The number of layers. (default: :obj:`6`)
            hidden_channels (int, optional): Hidden embedding size. (default: :obj:`128`)
            num_filters (int, optional): The number of filters to use. (default: :obj:`128`)
            num_gaussians (int, optional): The number of gaussians :math:`\mu`. (default: :obj:`50`)
            cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`10.0`).
    """
    def __init__(self, energy_and_force=False, cutoff=10.0, num_layers=6,
                 hidden_channels=128, num_filters=128, num_gaussians=50):
        super(SchNet, self).__init__()

        self.energy_and_force = energy_and_force
        self.cutoff = cutoff
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_gaussians = num_gaussians

        # Atom-type embedding: atomic numbers z must lie in [0, 100).
        self.init_v = Embedding(100, hidden_channels)
        self.dist_emb = emb(0.0, cutoff, num_gaussians)

        self.update_vs = torch.nn.ModuleList([
            update_v(hidden_channels, num_filters) for _ in range(num_layers)
        ])

        self.update_es = torch.nn.ModuleList([
            update_e(hidden_channels, num_filters, num_gaussians, cutoff)
            for _ in range(num_layers)
        ])

        self.update_u = update_u(hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding and every update block."""
        self.init_v.reset_parameters()
        # Loop names deliberately differ from the module-level classes
        # ``update_e`` / ``update_v`` so they are not shadowed.
        for edge_block in self.update_es:
            edge_block.reset_parameters()
        for node_block in self.update_vs:
            node_block.reset_parameters()
        self.update_u.reset_parameters()

    def forward(self, batch_data):
        """Return the output of ``update_u`` for ``batch_data``.

        ``batch_data`` must provide ``z`` (atom types), ``pos`` and ``batch``.
        When ``energy_and_force`` is set, ``pos.requires_grad_()`` is called
        in place so forces can later be obtained by differentiation.
        """
        z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch
        if self.energy_and_force:
            pos.requires_grad_()

        # NOTE(review): torch_cluster's radius_graph caps neighbours per node
        # (``max_num_neighbors``, default 32); confirm that is acceptable for
        # the chosen cutoff.
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        row, col = edge_index
        dist = (pos[row] - pos[col]).norm(dim=-1)
        dist_emb = self.dist_emb(dist)

        v = self.init_v(z)

        for edge_block, node_block in zip(self.update_es, self.update_vs):
            e = edge_block(v, dist, dist_emb, edge_index)
            v = node_block(v, e, edge_index)
        u = self.update_u(v, batch)

        return u
class BaseCrystalModel(Module):
    """
    Base model for crystal problem.

    Subclasses supply the interaction block (override
    ``get_interactions_layer`` or pass ``interactions``) and may customise the
    readout block (override ``get_readout_layer`` or pass ``readout_layer``).
    """

    def __init__(self,
                 num_node_features=1,
                 num_edge_features=3,
                 num_state_features=0,

                 num_node_hidden_channels=128,
                 num_node_interaction_channels=128,
                 num_interactions=1,

                 num_edge_gaussians=None,
                 num_node_embeddings=120,
                 cutoff=10.0,
                 out_size=1,
                 readout='add',
                 mean=None,
                 std=None,
                 norm=False,
                 atom_ref=None,
                 simple_z=True,
                 interactions=None,
                 readout_layer=None,
                 add_state=False,
                 **kwargs,
                 ):
        """
        Args:
            num_node_features: (int) input number of node feature (atom feature).
            num_edge_features: (int) input number of bond feature.
                Only used as the fallback when ``num_edge_gaussians`` is None.
            num_state_features: (int) input number of state feature.
            num_node_embeddings: (int) number of embeddings, used to generate the initial
                embedding matrix that stands in for node features.
            num_node_hidden_channels: (int) num_node_hidden_channels for node feature.
            num_node_interaction_channels: (int) channels for node feature.
            num_interactions: (int) conv number.
            num_edge_gaussians: (int) number of gaussian Smearing number for radius.
                deprecated; keep this consistent with your bond data.
            cutoff: (float) cutoff for calculate neighbor bond
            readout: (str) Merge node method. One of "add", "sum", "min", "mean", "max".
            mean: (float) mean
            std: (float) std
            norm: (bool) False or True
            atom_ref: per-atom reference values added to node features before readout.
                * a tensor whose first dimension is 120 (e.g. shape (120, 1)): fixed lookup by ``z``.
                  Such as target y is volumes of compound, atom_ref could be the atom volumes of all
                  atom (H,H,He,Li,...). And you could copy the first term to make sure the `H` index
                  starts from 1.
                * "embed": a learnable ``Embedding(120, 1)``.
                * None: disabled.
                * anything else is stored as-is and called as ``atom_ref(z)``.
            simple_z: (bool,str) just used "z" or used "x" to calculate.
            interactions: (Callable) torch module for interactions.
                dynamically: pass the torch module to the ``interactions`` parameter.
                static: re-define ``get_interactions_layer`` and keep this parameter None.
                the forward input is (h, edge_index, edge_weight, edge_attr, data=data)
            readout_layer: (Callable) torch module for readout.
                dynamically: pass the torch module to the ``readout_layer`` parameter.
                static: re-define ``get_readout_layer`` and keep this parameter None.
                the forward input is (out,)
            add_state: (bool) add state attribute before output.
            out_size: (int) number of out size. for regression, is 1 and for classification should be defined.
            **kwargs: ``interaction_kwargs_<name>`` / ``readout_kwargs_<name>`` entries are
                collected (prefix removed) into ``self.interaction_kwargs`` / ``self.readout_kwargs``.
        """
        super(BaseCrystalModel, self).__init__()

        self.interaction_kwargs = {}
        for k, v in kwargs.items():
            if "interaction_kwargs_" in k:
                self.interaction_kwargs[k.replace("interaction_kwargs_", "")] = v

        self.readout_kwargs = {}
        for k, v in kwargs.items():
            if "readout_kwargs_" in k:
                self.readout_kwargs[k.replace("readout_kwargs_", "")] = v

        # Basic hyper-parameters.
        if num_edge_gaussians is None:
            num_edge_gaussians = num_edge_features

        assert readout in ['add', 'sum', 'min', 'mean', "max"]

        self.num_node_hidden_channels = num_node_hidden_channels
        self.num_state_features = num_state_features
        self.num_node_interaction_channels = num_node_interaction_channels
        self.num_interactions = num_interactions
        self.num_edge_gaussians = num_edge_gaussians
        self.cutoff = cutoff
        self.readout = readout

        self.mean = mean
        self.std = std
        self.scale = None
        self.simple_z = simple_z
        self.interactions = interactions
        self.readout_layer = readout_layer
        self.out_size = out_size
        self.norm = norm

        # Reference atomic properties kept for later use (dipole_forward reads
        # atomic_mass). Keep this list short: it is rarely needed.
        atomic_mass = torch.from_numpy(ase_data.atomic_masses)  # atomic masses
        covalent_radii = torch.from_numpy(ase_data.covalent_radii)  # covalent radii
        # Tensors must be registered as buffers; assigning a tensor to a plain
        # Module attribute does not make it a buffer, so it would be invisible
        # to state_dict(), buffers() and named_buffers().
        self.register_buffer('atomic_mass', atomic_mass)
        self.register_buffer('atomic_radii', covalent_radii)

        # Node input: either atomic properties, or an Embedding of ``z``.
        if num_node_embeddings < 120:
            print(
                "default, num_node_embeddings>=120,if you want simple the net work and "
                "This network does not apply to other elements, the num_node_embeddings could be less but large than "
                "the element type number in your data.")

        # num_node_embeddings is the number of element types and usually needs no
        # change. Embedding-based networks generalise poorly to elements that are
        # absent from the training set.
        if simple_z is True:
            if num_node_features != 0:
                warnings.warn(
                    "simple_z just accept num_node_features == 0, "
                    "and don't use your self-defined 'x' data, but element number Z", UserWarning)

            self.embedding_e = Embedding(num_node_embeddings, num_node_hidden_channels)
        elif self.simple_z == "no_embed":
            self.embedding_l = Linear(num_node_features, num_node_hidden_channels)
            self.embedding_l2 = Linear(num_node_hidden_channels, num_node_hidden_channels)
        else:
            assert num_node_features > 0, "The `num_node_features` must be the same size with `x` feature."
            self.embedding_e = Embedding(num_node_embeddings, num_node_hidden_channels)
            self.embedding_l = Linear(num_node_features, num_node_hidden_channels)
            self.embedding_l2 = Linear(num_node_hidden_channels, num_node_hidden_channels)

        self.bn = BatchNorm1d(num_node_hidden_channels)

        # Interaction block: from the argument, or built by get_interactions_layer.
        if interactions is None:
            self.get_interactions_layer()
        elif isinstance(interactions, ModuleList):
            self.get_res_interactions_layer(interactions)
        elif isinstance(interactions, Module):
            self.interactions = interactions
        else:
            raise NotImplementedError(
                "please implement get_interactions_layer function, "
                "or pass interactions parameters.")

        # Readout block: from the argument, or built by get_readout_layer.
        if readout_layer is None:
            self.get_readout_layer()
        elif isinstance(readout_layer, Module):
            self.readout_layer = readout_layer
        else:
            raise NotImplementedError(
                "please implement get_readout_layer function, "
                "or pass readout_layer parameters.")

        # Optional per-atom term added in internal_forward.
        self._init_atom_ref(atom_ref)

        self.add_state = add_state

        if self.add_state:
            self.dp = LayerNorm(self.num_state_features)
            self.ads = Linear(self.num_state_features, 1)
            self.ads2 = Linear(1, self.out_size)

        self.reset_parameters()

    def _init_atom_ref(self, atom_ref):
        """Set ``self.atom_ref`` and the ``initial_atom_ref`` buffer from the
        ``atom_ref`` constructor argument."""
        # register_buffer only accepts tensors or None; other atom_ref values
        # ("embed", callables, modules) leave the buffer slot empty.
        initial = atom_ref if isinstance(atom_ref, Tensor) else None
        self.register_buffer('initial_atom_ref', initial)
        self._atom_ref_is_learnable = False

        if atom_ref is None:
            self.atom_ref = atom_ref
        elif isinstance(atom_ref, Tensor) and atom_ref.shape[0] == 120:
            # Fixed lookup that reads the buffer, so it follows .to(device).
            self.atom_ref = self._atom_ref_lookup
        elif isinstance(atom_ref, str) and atom_ref == "embed":
            # Learnable per-element term (default Embedding initialisation).
            self.atom_ref = Embedding(120, 1)
            self._atom_ref_is_learnable = True
        else:
            self.atom_ref = atom_ref

    def _atom_ref_lookup(self, z):
        """Return the stored reference values for atom types ``z``."""
        return self.initial_atom_ref[z]

    def forward(self, data):
        # The embedding of ``z`` stands in for real atom features unless
        # ``simple_z`` selects one of the ``x``-based paths below.
        assert hasattr(data, "z")
        assert hasattr(data, "pos")
        assert hasattr(data, "batch")
        z = data.z
        batch = data.batch
        # pos = data.pos

        # Build initial node features.
        if self.simple_z is True:
            assert z.dim() == 1 and z.dtype == torch.long

            h = self.embedding_e(z)
            h = F.softplus(h)
        elif self.simple_z == "no_embed":
            assert hasattr(data, "x")
            x = data.x
            h = F.softplus(self.embedding_l(x))
            h = self.embedding_l2(h)
        else:
            assert hasattr(data, "x")
            x = data.x
            h1 = self.embedding_e(z)
            x = F.softplus(self.embedding_l(x))
            h2 = self.embedding_l2(x)
            h = h1 + h2

        data.x = h

        if data.edge_index is None:
            edge_index = data.adj_t
        else:
            edge_index = data.edge_index

        # Derive edge_weight from edge_attr when it is not supplied.
        if not hasattr(data, "edge_weight") or data.edge_weight is None:
            if not hasattr(data, "edge_attr") or data.edge_attr is None:
                raise NotImplementedError("Must offer edge_weight or edge_attr.")
            else:
                # NOTE(review): the single-column path yields shape (E,), the
                # multi-column path (E, 1); confirm downstream interactions
                # accept both.
                if data.edge_attr.shape[1] == 1:
                    data.edge_weight = data.edge_attr.reshape(-1)
                else:
                    data.edge_weight = torch.norm(data.edge_attr, dim=1, keepdim=True)

        h = self.bn(h)

        h = self.interactions(h, edge_index, data.edge_weight, data.edge_attr, data=data)

        h = self.internal_forward(h, z)

        out = self.readout_layer(h, batch)

        if self.add_state:
            assert hasattr(data, "state_attr"), "the ``add_state`` must accept ``state_attr`` in data."
            sta = self.dp(data.state_attr)
            sta = self.ads(sta)
            sta = F.relu(sta)
            out = self.ads2(out.expand_as(sta) + sta)

        return self.output_forward(out)

    def get_res_interactions_layer(self, interactions):
        """Wrap a ModuleList of interaction layers in ``GeoResNet``."""
        self.interactions = GeoResNet(interactions)

    @abstractmethod
    def get_interactions_layer(self):
        """This part should be re-defined. It must set the ``interactions`` attribute.

        Examples::

            >>> ...
            >>> self.interactions = YourNet()
        """

    def get_readout_layer(self):
        """This part may be re-defined. It must set the ``readout_layer`` attribute.

        Examples::

            >>> self.readout_layer = torch.nn.Sequential(...)

        Examples::

            >>> ...
            >>> self.readout_layer = YourNet()
        """
        # self.readout_kwargs keys have the "readout_kwargs_" prefix already
        # stripped, so a user-supplied ``readout_kwargs_layers_size`` arrives
        # here as "layers_size".
        if "layers_size" in self.readout_kwargs:
            self.readout_layer = GeneralReadOutLayer(**self.readout_kwargs)
        else:
            self.readout_layer = GeneralReadOutLayer(
                [self.num_node_interaction_channels, self.readout, 200, self.out_size], **self.readout_kwargs)

    def output_forward(self, out):
        """Undo target standardisation, optionally take the norm and scale,
        then flatten to 1-D."""
        if self.mean is not None and self.std is not None:
            out = out * self.std + self.mean
        if self.norm is True:
            out = torch.norm(out, dim=-1, keepdim=True)
        if self.scale is not None:
            out = self.scale * out
        return out.view(-1)

    def dipole_forward(self, h, z, pos, batch):
        """Scale node features by each atom's offset from its graph's centre of
        mass (dipole-moment style term)."""
        # Get center of mass.
        mass = self.atomic_mass[z].view(-1, 1)
        c = segment_csr(mass * pos, get_ptr(batch)) / segment_csr(mass, get_ptr(batch))
        h = h * (pos - c[batch])
        return h

    def internal_forward(self, h, z):
        """Add the per-atom reference term, if configured."""
        if self.atom_ref is not None:
            h = h + self.atom_ref(z)

        return h

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'num_node_hidden_channels={self.num_node_hidden_channels}, '
                f'num_node_interaction_channels={self.num_node_interaction_channels}, '
                f'num_interactions={self.num_interactions}, '
                f'num_edge_gaussians={self.num_edge_gaussians}, '
                f'cutoff={self.cutoff})')

    def reset_parameters(self):
        """Re-initialise the input layers, batch norm and a learnable atom_ref.

        Interaction and readout blocks are left untouched.
        """
        if hasattr(self, "embedding_e"):
            self.embedding_e.reset_parameters()
        if hasattr(self, "embedding_l"):
            self.embedding_l.reset_parameters()
        if hasattr(self, "embedding_l2"):
            self.embedding_l2.reset_parameters()
        self.bn.reset_parameters()
        # Only the internally created "embed" Embedding is reset; fixed lookups
        # and user-supplied atom_ref objects have no state to restore here.
        if getattr(self, "_atom_ref_is_learnable", False):
            self.atom_ref.reset_parameters()