예제 #1
0
파일: model.py 프로젝트: zhy5186612/DIG
class schnet(torch.nn.Module):
    """Message-passing network assembled from edge, vertex and global update blocks.

    Every layer first computes edge messages with an ``update_e`` block and
    then new node states with an ``update_v`` block.  After the last layer an
    ``update_u`` block maps the node states and the batch vector to the
    returned output.

    Args:
        energy_and_force: If truthy, ``forward`` switches on gradient tracking
            for the input positions (presumably so the caller can
            differentiate the output w.r.t. ``pos`` -- confirm against the
            training loop).
        cutoff: Radius handed to ``radius_graph``; also forwarded to ``emb``
            and to every ``update_e`` block.
        num_layers: Number of (``update_e``, ``update_v``) pairs.
        hidden_channels: Width of the node embedding table.
        num_filters: Filter width forwarded to the update blocks.
        num_gaussians: Size argument forwarded to ``emb`` and ``update_e``.
    """

    def __init__(self, energy_and_force, cutoff=10.0, num_layers=6, hidden_channels=128, num_filters=128, num_gaussians=50):
        super().__init__()

        # Keep the hyper-parameters around for later inspection.
        self.energy_and_force = energy_and_force
        self.cutoff = cutoff
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_gaussians = num_gaussians

        # Node embedding table (100 rows) and distance embedding.
        self.init_v = Embedding(100, hidden_channels)
        self.dist_emb = emb(0.0, cutoff, num_gaussians)

        # All vertex blocks are created before any edge block so that the
        # registration (and initialisation) order of the submodules is fixed.
        vertex_blocks = torch.nn.ModuleList()
        for _ in range(num_layers):
            vertex_blocks.append(update_v(hidden_channels, num_filters))
        self.update_vs = vertex_blocks

        edge_blocks = torch.nn.ModuleList()
        for _ in range(num_layers):
            edge_blocks.append(
                update_e(hidden_channels, num_filters, num_gaussians, cutoff))
        self.update_es = edge_blocks

        self.update_u = update_u(hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding table and every update block."""
        self.init_v.reset_parameters()
        for edge_block in self.update_es:
            edge_block.reset_parameters()
        for vertex_block in self.update_vs:
            vertex_block.reset_parameters()
        self.update_u.reset_parameters()

    def forward(self, batch_data):
        """Run the network on a batch exposing ``z``, ``pos`` and ``batch``.

        Returns:
            Whatever the final ``update_u`` block produces from the node
            states and the batch vector.
        """
        z = batch_data.z
        pos = batch_data.pos
        batch = batch_data.batch
        if self.energy_and_force:
            # In-place flag: positions become a differentiable leaf.
            pos.requires_grad_()

        # Connect nodes that lie within `cutoff` of each other.
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        src, dst = edge_index
        dist = (pos[src] - pos[dst]).norm(dim=-1)
        dist_emb = self.dist_emb(dist)

        v = self.init_v(z)

        # Alternate edge and vertex updates, layer by layer.
        for edge_block, vertex_block in zip(self.update_es, self.update_vs):
            e = edge_block(v, dist, dist_emb, edge_index)
            v = vertex_block(v, e, edge_index)

        return self.update_u(v, batch)
예제 #2
0
class SpanRepresentation(Module):
    """Encode a token sequence into one vector via attention-weighted pooling.

    Tokens are embedded, optionally passed through ``LSTMContextualizer``
    (when ``config.n_lstm_layers > 0``), scored by a linear attention head and
    summed along dimension 1 after a linear transform to ``d_output`` features.

    Args:
        config: Object providing ``d_embed``, ``n_lstm_layers``, ``dropout``,
            ``d_lstm_hidden`` and optionally ``normalize_pretrained``.
        d_output: Output feature size of ``head_transform``.
        vocab: Sized object with a ``vectors`` attribute (pretrained
            embeddings or ``None``).
    """

    def __init__(self, config, d_output, vocab):
        super().__init__()
        self.config = config
        self.vocab = vocab
        self.embedding = Embedding(len(vocab), config.d_embed)
        self.normalize_pretrained = getattr(
            config, 'normalize_pretrained', False)

        # Without LSTM layers the contextualizer is the identity function.
        if config.n_lstm_layers > 0:
            self.contextualizer = LSTMContextualizer(config)
        else:
            self.contextualizer = lambda x: x

        # The same Dropout instance is shared by both heads and the embedding.
        self.dropout = Dropout(p=config.dropout)
        head_in = 2 * config.d_lstm_hidden
        self.head_attention = Sequential(self.dropout, Linear(head_in, 1))
        self.head_transform = Sequential(self.dropout,
                                         Linear(head_in, d_output))
        self.init()

    def init(self):
        """Xavier-initialise all multi-dimensional parameters, then load or
        reset the embedding table."""
        for param in self.parameters():
            if param.dim() > 1:
                xavier_normal(param)

        if self.vocab.vectors is None:
            # No pretrained vectors: fall back to the default embedding init.
            self.embedding.reset_parameters()
            return

        pretrained = self.vocab.vectors
        if self.normalize_pretrained:
            pretrained = normalize(pretrained, dim=-1)
        self.embedding.weight.data.copy_(pretrained)
        print('Copied pretrained vectors into relation span representation')

    def forward(self, inputs):
        """Return the pooled representation for ``inputs = (text, mask)``."""
        text, mask = inputs
        embedded = self.dropout(self.embedding(text))
        contextual = self.contextualizer(embedded)

        # Attention scores first, then the transform, so dropout is applied
        # in the same order as the two heads are listed.
        scores = self.head_attention(contextual).squeeze(-1)
        weights = masked_softmax(scores, mask.float())
        transformed = self.head_transform(contextual)
        return (weights.unsqueeze(2) * transformed).sum(dim=1)
예제 #3
0
class EmbGCN(torch.nn.Module):
    """GCN classifier whose node features come from a learned embedding table.

    The embedded nodes feed two parallel branches: a ``GCNConv`` branch
    (``conv1``) and a linear branch (``first_lin``) followed by the remaining
    ``GCNConv`` layers.  Both branches are summed before the final linear
    classifier, and the output is a log-softmax over the last dimension.

    Args:
        num_layers: Total layer count; ``num_layers - 1`` extra ``GCNConv``
            layers are created after ``conv1``.
        hidden: Requested hidden width (raised to ``2 * num_class`` if
            smaller).
        emb_dim: Size of each node embedding vector.
        num_class: Number of output classes.
        num_nodes: Number of rows in the embedding table.  Required.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``num_nodes`` is ``None``.
    """

    def __init__(self, num_layers=2, hidden=32, emb_dim=64, num_class=2, num_nodes=None, **kwargs):
        super(EmbGCN, self).__init__()
        # The embedding table cannot be sized without a node count; fail
        # early with a clear message instead of deep inside `Embedding`.
        if num_nodes is None:
            raise ValueError(
                'EmbGCN requires `num_nodes` to size its embedding table.')
        hidden = max(hidden, num_class * 2)
        self.conv1 = GCNConv(emb_dim, hidden)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))
        self.lin2 = Linear(hidden, num_class)
        self.emb = Embedding(num_nodes, emb_dim)
        self.first_lin = Linear(emb_dim, hidden)

    def reset_parameters(self):
        """Re-initialise every learnable submodule."""
        self.first_lin.reset_parameters()
        self.emb.reset_parameters()
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        """Return per-node log-probabilities.

        ``data`` must expose ``edge_index``, ``edge_weight`` and
        ``node_index``; node features are looked up from the embedding table
        via ``node_index``.
        """
        edge_index = data.edge_index
        edge_weight = data.edge_weight
        x = self.emb(data.node_index)

        # Branch 1: a single graph convolution on the raw embeddings.
        x1 = F.elu(self.conv1(x, edge_index, edge_weight=edge_weight))

        # Branch 2: linear projection, dropout, then the remaining convs.
        x = F.elu(self.first_lin(x))
        x = F.dropout(x, p=0.5, training=self.training)
        for conv in self.convs:
            x = F.elu(conv(x, edge_index, edge_weight=edge_weight))

        # Merge both branches and classify.
        x = x1 + x
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
예제 #4
0
class SchNet(torch.nn.Module):
    r"""The continuous-filter convolutional neural network SchNet from the
    `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling
    Quantum Interactions" <https://arxiv.org/abs/1706.08566>`_ paper that uses
    the interactions blocks of the form

    .. math::
        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot
        h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))),

    here :math:`h_{\mathbf{\Theta}}` denotes an MLP and
    :math:`\mathbf{e}_{j,i}` denotes the interatomic distances between atoms.

    .. note::

        For an example of using a pretrained SchNet variant, see
        `examples/qm9_pretrained_schnet.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        qm9_pretrained_schnet.py>`_.

    Args:
        hidden_channels (int, optional): Hidden embedding size.
            (default: :obj:`128`)
        num_filters (int, optional): The number of filters to use.
            (default: :obj:`128`)
        num_interactions (int, optional): The number of interaction blocks.
            (default: :obj:`6`)
        num_gaussians (int, optional): The number of gaussians :math:`\mu`.
            (default: :obj:`50`)
        cutoff (float, optional): Cutoff distance for interatomic interactions.
            (default: :obj:`10.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance.
            (default: :obj:`32`)
        readout (string, optional): Whether to apply :obj:`"add"` or
            :obj:`"mean"` global aggregation. Ignored (forced to
            :obj:`"add"`) when :attr:`dipole` is set.
            (default: :obj:`"add"`)
        dipole (bool, optional): If set to :obj:`True`, will use the magnitude
            of the dipole moment to make the final prediction, *e.g.*, for
            target 0 of :class:`torch_geometric.datasets.QM9`.
            (default: :obj:`False`)
        mean (float, optional): The mean of the property to predict.
            (default: :obj:`None`)
        std (float, optional): The standard deviation of the property to
            predict. (default: :obj:`None`)
        atomref (torch.Tensor, optional): The reference of single-atom
            properties.
            Expects a vector of shape :obj:`(max_atomic_number, )`.
    """

    url = 'http://www.quantum-machine.org/datasets/trained_schnet_models.zip'

    def __init__(self,
                 hidden_channels: int = 128,
                 num_filters: int = 128,
                 num_interactions: int = 6,
                 num_gaussians: int = 50,
                 cutoff: float = 10.0,
                 max_num_neighbors: int = 32,
                 readout: str = 'add',
                 dipole: bool = False,
                 mean: Optional[float] = None,
                 std: Optional[float] = None,
                 atomref: Optional[torch.Tensor] = None):
        super(SchNet, self).__init__()

        # Imported lazily so the module can be imported without `ase`.
        import ase

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.dipole = dipole
        # Dipole prediction always sums per-atom contributions, so the
        # requested readout is overridden in that case.
        self.readout = 'add' if dipole else readout
        self.mean = mean
        self.std = std
        self.scale = None

        # Per-element masses, indexed by `z` in `forward` for the dipole path.
        atomic_mass = torch.from_numpy(ase.data.atomic_masses)
        self.register_buffer('atomic_mass', atomic_mass)

        self.embedding = Embedding(100, hidden_channels)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(hidden_channels, num_gaussians,
                                     num_filters, cutoff)
            self.interactions.append(block)

        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(hidden_channels // 2, 1)

        # Keep the original reference values so `reset_parameters` can
        # restore them after training has modified the embedding.
        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all learnable parameters.

        The two output layers get Xavier-uniform weights and zero biases; the
        atom-reference embedding (if present) is restored from
        ``initial_atomref``.
        """
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
        if self.atomref is not None:
            self.atomref.weight.data.copy_(self.initial_atomref)

    @staticmethod
    def from_qm9_pretrained(root: str, dataset: Dataset, target: int):
        """Build a :class:`SchNet` from the pretrained QM9 checkpoint.

        Downloads and extracts the checkpoint archive into :obj:`root` if it
        is not already there, copies the stored weights into a fresh model
        and filters the stored train/val/test splits to the molecules present
        in :obj:`dataset`.

        Args:
            root (str): Directory used to cache the downloaded models.
            dataset (Dataset): Dataset providing ``data.idx``,
                ``atomref(target)`` and tensor indexing.
            target (int): Index of the property to predict; must lie in
                ``0..11`` (one entry per row of the unit table below).

        Returns:
            A tuple ``(net, (train_dataset, val_dataset, test_dataset))``.

        Raises:
            AssertionError: If :obj:`target` is outside ``0..11``.
        """
        # Validate before touching optional dependencies or the network.
        # The unit table has exactly `num_targets` entries, so the valid
        # indices are 0..num_targets-1 (the former `<= 12` bound let 12
        # through, which then failed with an IndexError on `units[target]`).
        num_targets = 12
        assert 0 <= target < num_targets

        import ase
        import schnetpack as spk  # noqa

        # Conversion factors for the targets that are not stored in the
        # units the dataset uses; every other target keeps a factor of 1.
        units = [1] * num_targets
        units[0] = ase.units.Debye
        units[1] = ase.units.Bohr**3
        units[5] = ase.units.Bohr**2

        root = osp.expanduser(osp.normpath(root))
        makedirs(root)
        folder = 'trained_schnet_models'
        if not osp.exists(osp.join(root, folder)):
            path = download_url(SchNet.url, root)
            extract_zip(path, root)
            os.unlink(path)

        name = f'qm9_{qm9_target_dict[target]}'
        path = osp.join(root, 'trained_schnet_models', name, 'split.npz')

        split = np.load(path)
        train_idx = split['train_idx']
        val_idx = split['val_idx']
        test_idx = split['test_idx']

        # Filter the splits to only contain characterized molecules.
        # `assoc` maps an original molecule id to its position in `dataset`.
        idx = dataset.data.idx
        assoc = idx.new_empty(idx.max().item() + 1)
        assoc[idx] = torch.arange(idx.size(0))

        train_idx = assoc[train_idx[np.isin(train_idx, idx)]]
        val_idx = assoc[val_idx[np.isin(val_idx, idx)]]
        test_idx = assoc[test_idx[np.isin(test_idx, idx)]]

        path = osp.join(root, 'trained_schnet_models', name, 'best_model')

        # SECURITY NOTE: `torch.load` unpickles the checkpoint, which can
        # execute arbitrary code, and the archive above is fetched over plain
        # HTTP.  Only use this with a source and network you trust.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            state = torch.load(path, map_location='cpu')

        net = SchNet(hidden_channels=128,
                     num_filters=128,
                     num_interactions=6,
                     num_gaussians=50,
                     cutoff=10.0,
                     atomref=dataset.atomref(target))

        # Transfer the pretrained weights module by module.
        net.embedding.weight = state.representation.embedding.weight

        for int1, int2 in zip(state.representation.interactions,
                              net.interactions):
            int2.mlp[0].weight = int1.filter_network[0].weight
            int2.mlp[0].bias = int1.filter_network[0].bias
            int2.mlp[2].weight = int1.filter_network[1].weight
            int2.mlp[2].bias = int1.filter_network[1].bias
            int2.lin.weight = int1.dense.weight
            int2.lin.bias = int1.dense.bias

            int2.conv.lin1.weight = int1.cfconv.in2f.weight
            int2.conv.lin2.weight = int1.cfconv.f2out.weight
            int2.conv.lin2.bias = int1.cfconv.f2out.bias

        net.lin1.weight = state.output_modules[0].out_net[1].out_net[0].weight
        net.lin1.bias = state.output_modules[0].out_net[1].out_net[0].bias
        net.lin2.weight = state.output_modules[0].out_net[1].out_net[1].weight
        net.lin2.bias = state.output_modules[0].out_net[1].out_net[1].bias

        mean = state.output_modules[0].atom_pool.average
        net.readout = 'mean' if mean is True else 'add'

        dipole = state.output_modules[0].__class__.__name__ == 'DipoleMoment'
        net.dipole = dipole

        net.mean = state.output_modules[0].standardize.mean.item()
        net.std = state.output_modules[0].standardize.stddev.item()

        if state.output_modules[0].atomref is not None:
            net.atomref.weight = state.output_modules[0].atomref.weight
        else:
            net.atomref = None

        net.scale = 1. / units[target]

        return net, (dataset[train_idx], dataset[val_idx], dataset[test_idx])

    def forward(self, z, pos, batch=None):
        """Predict one value per graph.

        Args:
            z: 1-D :obj:`torch.long` tensor used to index the embedding
                tables.
            pos: Node positions used to build the radius graph.
            batch: Graph assignment of every node; all nodes are placed in
                graph 0 when omitted.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)

        edge_index = radius_graph(pos,
                                  r=self.cutoff,
                                  batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        # Residual updates through the interaction blocks.
        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        if self.dipole:
            # Get center of mass.
            mass = self.atomic_mass[z].view(-1, 1)
            c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
            h = h * (pos - c[batch])

        # De-standardisation and atom references only apply to the scalar
        # (non-dipole) prediction path.
        if not self.dipole and self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if not self.dipole and self.atomref is not None:
            h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)

        if self.scale is not None:
            out = self.scale * out

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'hidden_channels={self.hidden_channels}, '
                f'num_filters={self.num_filters}, '
                f'num_interactions={self.num_interactions}, '
                f'num_gaussians={self.num_gaussians}, '
                f'cutoff={self.cutoff})')
예제 #5
0
class MetaPath2Vec(torch.nn.Module):
    r"""The MetaPath2Vec model from the `"metapath2vec: Scalable Representation
    Learning for Heterogeneous Networks"
    <https://ericdongyx.github.io/papers/
    KDD17-dong-chawla-swami-metapath2vec.pdf>`_ paper where random walks based
    on a given :obj:`metapath` are sampled in a heterogeneous graph, and node
    embeddings are learned via negative sampling optimization.

    .. note::

        For an example of using MetaPath2Vec, see `examples/metapath2vec.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        metapath2vec.py>`_.

    Args:
        edge_index_dict (dict): Dictionary holding edge indices for each
            :obj:`(source_node_type, relation_type, target_node_type)` present
            in the heterogeneous graph.
        embedding_dim (int): The size of each embedding vector.
        metapath (list): The metapath described as a list of
            :obj:`(source_node_type, relation_type, target_node_type)` tuples.
        walk_length (int): The walk length.
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for each
            node. (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples to
            use for each positive sample. (default: :obj:`1`)
        num_nodes_dict (dict, optional): Dictionary holding the number of nodes
            for each node type. (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the
            weight matrix will be sparse. (default: :obj:`False`)
    """
    def __init__(self, edge_index_dict, embedding_dim, metapath, walk_length,
                 context_size, walks_per_node=1, num_negative_samples=1,
                 num_nodes_dict=None, sparse=False):
        super(MetaPath2Vec, self).__init__()

        if num_nodes_dict is None:
            # Infer the node count of every type from the largest index seen
            # on either side of any edge type.
            num_nodes_dict = {}
            for keys, edge_index in edge_index_dict.items():
                key = keys[0]
                N = int(edge_index[0].max() + 1)
                num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))

                key = keys[-1]
                N = int(edge_index[1].max() + 1)
                num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))

        # One CPU adjacency per edge type; walks are sampled from these.
        adj_dict = {}
        for keys, edge_index in edge_index_dict.items():
            sizes = (num_nodes_dict[keys[0]], num_nodes_dict[keys[-1]])
            row, col = edge_index
            adj = SparseTensor(row=row, col=col, sparse_sizes=sizes)
            adj = adj.to('cpu')
            adj_dict[keys] = adj

        # The metapath must end on the node type it starts from so that it
        # can be repeated when the walk is longer than the metapath.
        assert metapath[0][0] == metapath[-1][-1]
        assert walk_length >= context_size

        self.adj_dict = adj_dict
        self.embedding_dim = embedding_dim
        self.metapath = metapath
        self.walk_length = walk_length
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.num_negative_samples = num_negative_samples
        self.num_nodes_dict = num_nodes_dict

        # All node types touched by the metapath, in a deterministic order.
        types = sorted({t for keys in metapath for t in (keys[0], keys[-1])})

        # Every node type owns a contiguous row range [start, end) of the
        # single shared embedding table.
        count = 0
        self.start, self.end = {}, {}
        for key in types:
            self.start[key] = count
            count += num_nodes_dict[key]
            self.end[key] = count

        # offset[i] shifts the type-local index at walk position i into the
        # shared table.
        offset = [self.start[metapath[0][0]]]
        offset += [self.start[keys[-1]] for keys in metapath
                   ] * int((walk_length / len(metapath)) + 1)
        offset = offset[:walk_length + 1]
        assert len(offset) == walk_length + 1
        self.offset = torch.tensor(offset)

        self.embedding = Embedding(count, embedding_dim, sparse=sparse)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the shared embedding table."""
        self.embedding.reset_parameters()

    def forward(self, node_type, batch=None):
        """Returns the embeddings for the nodes in :obj:`batch` of type
        :obj:`node_type` (all nodes of that type if :obj:`batch` is
        :obj:`None`)."""
        emb = self.embedding.weight[self.start[node_type]:self.end[node_type]]
        return emb if batch is None else emb[batch]

    def loader(self, **kwargs):
        """Return a ``DataLoader`` over the start-type node indices whose
        batches are ``(pos_rw, neg_rw)`` pairs produced by :meth:`sample`."""
        return DataLoader(range(self.num_nodes_dict[self.metapath[0][0]]),
                          collate_fn=self.sample, **kwargs)

    def pos_sample(self, batch):
        """Sample metapath-guided walks starting at ``batch`` and return all
        context windows of size ``context_size``."""
        batch = batch.repeat(self.walks_per_node)

        rws = [batch]
        for i in range(self.walk_length):
            keys = self.metapath[i % len(self.metapath)]
            adj = self.adj_dict[keys]
            # Drop only the trailing `num_neighbors` dimension: a bare
            # `squeeze()` would also collapse a single-row batch to a 0-dim
            # tensor and make the `stack` below fail.
            batch = adj.sample(num_neighbors=1, subset=batch).squeeze(-1)
            rws.append(batch)

        rw = torch.stack(rws, dim=-1)
        rw.add_(self.offset.view(1, -1))

        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        walks = [rw[:, j:j + self.context_size]
                 for j in range(num_walks_per_rw)]
        return torch.cat(walks, dim=0)

    def neg_sample(self, batch):
        """Build "walks" whose non-start nodes are drawn uniformly at random
        from the node type expected at each position, and return all context
        windows of size ``context_size``."""
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)

        rws = [batch]
        for i in range(self.walk_length):
            keys = self.metapath[i % len(self.metapath)]
            batch = torch.randint(0, self.num_nodes_dict[keys[-1]],
                                  (batch.size(0), ), dtype=torch.long)
            rws.append(batch)

        rw = torch.stack(rws, dim=-1)
        rw.add_(self.offset.view(1, -1))

        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        walks = [rw[:, j:j + self.context_size]
                 for j in range(num_walks_per_rw)]
        return torch.cat(walks, dim=0)

    def sample(self, batch):
        """Collate function: convert ``batch`` to a tensor if needed and
        return ``(pos_sample(batch), neg_sample(batch))``."""
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)
        return self.pos_sample(batch), self.neg_sample(batch)

    def loss(self, pos_rw, neg_rw):
        r"""Computes the loss given positive and negative random walks."""

        # Positive loss.
        start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous()

        h_start = self.embedding(start).view(pos_rw.size(0), 1,
                                             self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(pos_rw.size(0), -1,
                                                    self.embedding_dim)

        out = (h_start * h_rest).sum(dim=-1).view(-1)
        pos_loss = -torch.log(torch.sigmoid(out) + EPS).mean()

        # Negative loss.
        start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous()

        h_start = self.embedding(start).view(neg_rw.size(0), 1,
                                             self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(neg_rw.size(0), -1,
                                                    self.embedding_dim)

        out = (h_start * h_rest).sum(dim=-1).view(-1)
        neg_loss = -torch.log(1 - torch.sigmoid(out) + EPS).mean()

        return pos_loss + neg_loss

    def test(self, train_z, train_y, test_z, test_y, solver='lbfgs',
             multi_class='auto', *args, **kwargs):
        r"""Evaluates latent space quality via a logistic regression downstream
        task."""
        from sklearn.linear_model import LogisticRegression

        clf = LogisticRegression(solver=solver, multi_class=multi_class, *args,
                                 **kwargs).fit(train_z.detach().cpu().numpy(),
                                               train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
                         test_y.detach().cpu().numpy())

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__,
                                   self.embedding.weight.size(0),
                                   self.embedding.weight.size(1))
예제 #6
0
class Node2Vec(torch.nn.Module):
    r"""The Node2Vec model from the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper where random walks of
    length :obj:`walk_length` are sampled in a given graph, and node embeddings
    are learned via negative sampling optimization.

    .. note::

        For an example of using Node2Vec, see `examples/node2vec.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        node2vec.py>`_.

    Args:
        edge_index (LongTensor): The edge indices.
        embedding_dim (int): The size of each embedding vector.
        walk_length (int): The walk length.
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for each
            node. (default: :obj:`1`)
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples to
            use for each positive sample. (default: :obj:`1`)
        num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the
            weight matrix will be sparse. (default: :obj:`False`)
    """
    def __init__(self, edge_index, embedding_dim, walk_length, context_size,
                 walks_per_node=1, p=1, q=1, num_negative_samples=1,
                 num_nodes=None, sparse=False):
        super().__init__()

        if random_walk is None:
            raise ImportError('`Node2Vec` requires `torch-cluster`.')

        num_total = maybe_num_nodes(edge_index, num_nodes)
        src, dst = edge_index
        # Walks are sampled on the CPU copy of the adjacency.
        self.adj = SparseTensor(row=src, col=dst,
                                sparse_sizes=(num_total, num_total)).to('cpu')

        assert walk_length >= context_size

        self.embedding_dim = embedding_dim
        # Stored as the number of steps taken after the start node.
        self.walk_length = walk_length - 1
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples

        self.embedding = Embedding(num_total, embedding_dim, sparse=sparse)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding table."""
        self.embedding.reset_parameters()

    def forward(self, batch=None):
        """Returns the embeddings for the nodes in :obj:`batch`."""
        weight = self.embedding.weight
        if batch is None:
            return weight
        return weight.index_select(0, batch)

    def loader(self, **kwargs):
        """Return a ``DataLoader`` over all node indices whose batches are
        ``(pos_rw, neg_rw)`` pairs produced by :meth:`sample`."""
        num_rows = self.adj.sparse_size(0)
        return DataLoader(range(num_rows), collate_fn=self.sample, **kwargs)

    def pos_sample(self, batch):
        """Sample biased random walks from ``batch`` and return every context
        window of size ``context_size``."""
        starts = batch.repeat(self.walks_per_node)
        rowptr, col, _ = self.adj.csr()
        rw = random_walk(rowptr, col, starts, self.walk_length, self.p,
                         self.q)
        # Some `random_walk` variants return a tuple; the walk comes first.
        if not isinstance(rw, torch.Tensor):
            rw = rw[0]

        window = self.context_size
        num_windows = 1 + self.walk_length + 1 - window
        return torch.cat([rw[:, k:k + window] for k in range(num_windows)],
                         dim=0)

    def neg_sample(self, batch):
        """Return context windows from "walks" whose non-start nodes are
        drawn uniformly at random from all nodes."""
        starts = batch.repeat(self.walks_per_node * self.num_negative_samples)

        random_nodes = torch.randint(self.adj.sparse_size(0),
                                     (starts.size(0), self.walk_length))
        rw = torch.cat([starts.view(-1, 1), random_nodes], dim=-1)

        window = self.context_size
        num_windows = 1 + self.walk_length + 1 - window
        return torch.cat([rw[:, k:k + window] for k in range(num_windows)],
                         dim=0)

    def sample(self, batch):
        """Collate function: convert ``batch`` to a tensor if needed and
        return ``(pos_sample(batch), neg_sample(batch))``."""
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)
        positives = self.pos_sample(batch)
        negatives = self.neg_sample(batch)
        return positives, negatives

    def loss(self, pos_rw, neg_rw):
        r"""Computes the loss given positive and negative random walks."""

        def pair_scores(walks):
            # Dot product between the first node of every window and each of
            # the remaining nodes in that window, flattened to 1-D.
            num_rows = walks.size(0)
            first = walks[:, 0]
            others = walks[:, 1:].contiguous()
            h_first = self.embedding(first).view(num_rows, 1,
                                                 self.embedding_dim)
            h_others = self.embedding(others.view(-1)).view(
                num_rows, -1, self.embedding_dim)
            return (h_first * h_others).sum(dim=-1).view(-1)

        # Positive pairs should score high, negative pairs low.
        pos_scores = pair_scores(pos_rw)
        pos_loss = -torch.log(torch.sigmoid(pos_scores) + EPS).mean()

        neg_scores = pair_scores(neg_rw)
        neg_loss = -torch.log(1 - torch.sigmoid(neg_scores) + EPS).mean()

        return pos_loss + neg_loss

    def test(self, train_z, train_y, test_z, test_y, solver='lbfgs',
             multi_class='auto', *args, **kwargs):
        r"""Evaluates latent space quality via a logistic regression downstream
        task."""
        from sklearn.linear_model import LogisticRegression

        classifier = LogisticRegression(solver=solver,
                                        multi_class=multi_class, *args,
                                        **kwargs)
        classifier = classifier.fit(train_z.detach().cpu().numpy(),
                                    train_y.detach().cpu().numpy())
        return classifier.score(test_z.detach().cpu().numpy(),
                                test_y.detach().cpu().numpy())

    def __repr__(self) -> str:
        num_rows, dim = self.embedding.weight.size(0), self.embedding.weight.size(1)
        return f'{self.__class__.__name__}({num_rows}, {dim})'
예제 #7
0
class MetaPath2Vec(torch.nn.Module):
    r"""The MetaPath2Vec model from the `"metapath2vec: Scalable Representation
    Learning for Heterogeneous Networks"
    <https://ericdongyx.github.io/papers/
    KDD17-dong-chawla-swami-metapath2vec.pdf>`_ paper. Random walks that follow
    a given :obj:`metapath` are sampled in a heterogeneous graph, and node
    embeddings are learned from them via negative sampling optimization.

    All node types that occur in the metapath share a single embedding table;
    each type owns a contiguous slice of it (see :obj:`start` / :obj:`end`),
    and one extra trailing row serves as a dummy node.

    .. note::

        For an example of using MetaPath2Vec, see
        `examples/hetero/metapath2vec.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        hetero/metapath2vec.py>`_.

    Args:
        edge_index_dict (Dict[Tuple[str, str, str], Tensor]): Dictionary
            holding edge indices for each
            :obj:`(src_node_type, rel_type, dst_node_type)` present in the
            heterogeneous graph.
        embedding_dim (int): The size of each embedding vector.
        metapath (List[Tuple[str, str, str]]): The metapath described as a list
            of :obj:`(src_node_type, rel_type, dst_node_type)` tuples.
        walk_length (int): The walk length.
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for each
            node. (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples to
            use for each positive sample. (default: :obj:`1`)
        num_nodes_dict (Dict[str, int], optional): Dictionary holding the
            number of nodes for each node type. If :obj:`None`, the counts are
            inferred from the largest index found in :obj:`edge_index_dict`.
            (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the
            weight matrix will be sparse. (default: :obj:`False`)
    """
    def __init__(
        self,
        edge_index_dict: Dict[EdgeType, Tensor],
        embedding_dim: int,
        metapath: List[EdgeType],
        walk_length: int,
        context_size: int,
        walks_per_node: int = 1,
        num_negative_samples: int = 1,
        num_nodes_dict: Optional[Dict[NodeType, int]] = None,
        sparse: bool = False,
    ):
        super().__init__()

        if num_nodes_dict is None:
            # Infer the node count of every type from the largest index that
            # appears on the source or destination side of any relation.
            num_nodes_dict = {}
            for edge_type, edge_index in edge_index_dict.items():
                endpoints = ((edge_type[0], edge_index[0]),
                             (edge_type[-1], edge_index[1]))
                for node_type, indices in endpoints:
                    seen = int(indices.max() + 1)
                    known = num_nodes_dict.get(node_type, seen)
                    num_nodes_dict[node_type] = max(seen, known)

        # One CPU-resident sparse adjacency per relation.
        adj_dict = {}
        for edge_type, edge_index in edge_index_dict.items():
            src_type, dst_type = edge_type[0], edge_type[-1]
            row, col = edge_index
            shape = (num_nodes_dict[src_type], num_nodes_dict[dst_type])
            adj_dict[edge_type] = SparseTensor(row=row, col=col,
                                               sparse_sizes=shape).to('cpu')

        assert walk_length + 1 >= context_size
        if walk_length > len(metapath):
            # Walking past the end of the metapath wraps around to its start,
            # which requires the last node type to equal the first one.
            if metapath[0][0] != metapath[-1][-1]:
                raise AttributeError(
                    "The 'walk_length' is longer than the given 'metapath', but "
                    "the 'metapath' does not denote a cycle")

        self.adj_dict = adj_dict
        self.embedding_dim = embedding_dim
        self.metapath = metapath
        self.walk_length = walk_length
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.num_negative_samples = num_negative_samples
        self.num_nodes_dict = num_nodes_dict

        # Node types touched by the metapath, in a deterministic order.
        node_types = sorted({edge_type[0] for edge_type in metapath}
                            | {edge_type[-1] for edge_type in metapath})

        # Assign each node type a half-open range [start, end) of rows in the
        # shared embedding table.
        total = 0
        self.start, self.end = {}, {}
        for node_type in node_types:
            self.start[node_type] = total
            total += num_nodes_dict[node_type]
            self.end[node_type] = total

        # Per-step offset that turns type-local indices into table rows:
        # position 0 is the start type, position i the destination type of the
        # (i - 1)-th (cyclically repeated) metapath relation.
        repeats = int((walk_length / len(metapath)) + 1)
        step_offsets = [self.start[edge_type[-1]] for edge_type in metapath]
        offsets = [self.start[metapath[0][0]]] + step_offsets * repeats
        offsets = offsets[:walk_length + 1]
        assert len(offsets) == walk_length + 1
        self.offset = torch.tensor(offsets)

        # + 1 denotes a dummy node used to link to for isolated nodes.
        self.embedding = Embedding(total + 1, embedding_dim, sparse=sparse)
        self.dummy_idx = total

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes the shared embedding table."""
        self.embedding.reset_parameters()

    def forward(self, node_type: str, batch: OptTensor = None) -> Tensor:
        r"""Returns the embeddings for the nodes in :obj:`batch` of type
        :obj:`node_type`. If :obj:`batch` is :obj:`None`, the embeddings of all
        nodes of that type are returned."""
        lower, upper = self.start[node_type], self.end[node_type]
        type_emb = self.embedding.weight[lower:upper]
        if batch is None:
            return type_emb
        return type_emb.index_select(0, batch)

    def loader(self, **kwargs):
        r"""Returns the data loader that creates both positive and negative
        random walks on the heterogeneous graph. The loader iterates over the
        indices of all nodes of the metapath's first node type.

        Args:
            **kwargs (optional): Arguments of
                :class:`torch.utils.data.DataLoader`, such as
                :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or
                :obj:`num_workers`.
        """
        first_type = self.metapath[0][0]
        seeds = range(self.num_nodes_dict[first_type])
        return DataLoader(seeds, collate_fn=self._sample, **kwargs)

    def _pos_sample(self, batch: Tensor) -> Tensor:
        """Samples metapath-guided walks from ``batch`` and cuts them into
        overlapping windows of ``context_size`` nodes."""
        current = batch.repeat(self.walks_per_node)

        steps = [current]
        for i in range(self.walk_length):
            edge_type = self.metapath[i % len(self.metapath)]
            # NOTE(review): `sample` is defined elsewhere; presumably it
            # returns `dummy_idx` for nodes without neighbors -- confirm.
            current = sample(self.adj_dict[edge_type], current,
                             num_neighbors=1,
                             dummy_idx=self.dummy_idx).view(-1)
            steps.append(current)

        rw = torch.stack(steps, dim=-1)
        rw.add_(self.offset.view(1, -1))
        # Anything pushed beyond the table by the offset is the dummy node.
        rw[rw > self.dummy_idx] = self.dummy_idx

        num_windows = 1 + self.walk_length + 1 - self.context_size
        windows = [rw[:, j:j + self.context_size] for j in range(num_windows)]
        return torch.cat(windows, dim=0)

    def _neg_sample(self, batch: Tensor) -> Tensor:
        """Builds negative walks whose non-start nodes are drawn uniformly at
        random from the node type expected at each step."""
        current = batch.repeat(self.walks_per_node *
                               self.num_negative_samples)

        steps = [current]
        for i in range(self.walk_length):
            dst_type = self.metapath[i % len(self.metapath)][-1]
            current = torch.randint(0, self.num_nodes_dict[dst_type],
                                    (current.size(0), ), dtype=torch.long)
            steps.append(current)

        rw = torch.stack(steps, dim=-1)
        rw.add_(self.offset.view(1, -1))

        num_windows = 1 + self.walk_length + 1 - self.context_size
        windows = [rw[:, j:j + self.context_size] for j in range(num_windows)]
        return torch.cat(windows, dim=0)

    def _sample(self, batch: List[int]) -> Tuple[Tensor, Tensor]:
        """Collate function: returns ``(positive_walks, negative_walks)`` for
        the given start node indices."""
        if isinstance(batch, Tensor):
            seeds = batch
        else:
            seeds = torch.tensor(batch, dtype=torch.long)
        pos_rw = self._pos_sample(seeds)
        neg_rw = self._neg_sample(seeds)
        return pos_rw, neg_rw

    def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
        r"""Computes the loss given positive and negative random walks.

        In every walk (row), the first entry is paired with each remaining
        entry of that row and scored by the dot product of their embeddings.
        """
        def pair_scores(walks: Tensor) -> Tensor:
            # One dot product per (first node, other node) pair, flattened.
            num_walks = walks.size(0)
            head = walks[:, 0]
            tail = walks[:, 1:].contiguous()

            head_emb = self.embedding(head)
            head_emb = head_emb.view(num_walks, 1, self.embedding_dim)
            tail_emb = self.embedding(tail.view(-1))
            tail_emb = tail_emb.view(num_walks, -1, self.embedding_dim)

            return (head_emb * tail_emb).sum(dim=-1).view(-1)

        # EPS keeps the logarithm finite when the sigmoid saturates.
        pos_prob = torch.sigmoid(pair_scores(pos_rw))
        pos_term = -torch.log(pos_prob + EPS).mean()

        neg_prob = torch.sigmoid(pair_scores(neg_rw))
        neg_term = -torch.log(1 - neg_prob + EPS).mean()

        return pos_term + neg_term

    def test(self, train_z: Tensor, train_y: Tensor, test_z: Tensor,
             test_y: Tensor, solver: str = "lbfgs", multi_class: str = "auto",
             *args, **kwargs) -> float:
        r"""Evaluates latent space quality via a logistic regression downstream
        task. Extra positional and keyword arguments are forwarded to the
        scikit-learn :obj:`LogisticRegression` constructor."""
        # Imported lazily so scikit-learn is only needed when evaluating.
        from sklearn.linear_model import LogisticRegression

        def as_numpy(tensor: Tensor):
            # Drop autograd history and move to host memory before converting.
            return tensor.detach().cpu().numpy()

        classifier = LogisticRegression(solver=solver,
                                        multi_class=multi_class, *args,
                                        **kwargs)
        classifier = classifier.fit(as_numpy(train_z), as_numpy(train_y))
        return classifier.score(as_numpy(test_z), as_numpy(test_y))

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        weight = self.embedding.weight
        # The trailing dummy row is not counted as a node.
        return f'{cls_name}({weight.size(0) - 1}, {weight.size(1)})'
예제 #8
0
class HEATConv(MessagePassing):
    r"""The heterogeneous edge-enhanced graph attentional operator from the
    `"Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent
    Trajectory Prediction" <https://arxiv.org/abs/2106.07161>`_ paper, which
    enhances :class:`~torch_geometric.nn.conv.GATConv` by:

    1. type-specific transformations of nodes of different types
    2. edge type and edge feature incorporation, in which edges are assumed to
       have different types but contain the same kind of attributes

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        num_node_types (int): The number of node types.
        num_edge_types (int): The number of edge types.
        edge_type_emb_dim (int): The embedding size of edge types.
        edge_dim (int): Edge feature dimensionality.
        edge_attr_emb_dim (int): The embedding size of edge features.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 num_node_types: int,
                 num_edge_types: int,
                 edge_type_emb_dim: int,
                 edge_dim: int,
                 edge_attr_emb_dim: int,
                 heads: int = 1,
                 concat: bool = True,
                 negative_slope: float = 0.2,
                 dropout: float = 0.0,
                 root_weight: bool = True,
                 bias: bool = True,
                 **kwargs):

        # Sum-aggregation unless the caller explicitly asks for another one.
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.root_weight = root_weight

        # Node-type specific linear transformation of the input features.
        self.hetero_lin = HeteroLinear(in_channels, out_channels,
                                       num_node_types, bias=bias)

        # Embeddings for the edge type and the edge attributes.
        self.edge_type_emb = Embedding(num_edge_types, edge_type_emb_dim)
        self.edge_attr_emb = Linear(edge_dim, edge_attr_emb_dim, bias=False)

        # Attention scores (one per head) are computed from the target node,
        # the source node, the edge type embedding and the edge attribute
        # embedding, concatenated in that order.
        att_in_channels = (2 * out_channels + edge_type_emb_dim +
                           edge_attr_emb_dim)
        self.att = Linear(att_in_channels, self.heads, bias=False)

        # Message content is computed from the source node features and the
        # edge attribute embedding.
        msg_in_channels = out_channels + edge_attr_emb_dim
        self.lin = Linear(msg_in_channels, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes all learnable sub-modules."""
        self.hetero_lin.reset_parameters()
        self.edge_type_emb.reset_parameters()
        self.edge_attr_emb.reset_parameters()
        self.att.reset_parameters()
        self.lin.reset_parameters()

    def forward(self,
                x: Tensor,
                edge_index: Adj,
                node_type: Tensor,
                edge_type: Tensor,
                edge_attr: OptTensor = None,
                size: Size = None) -> Tensor:
        """"""
        # Returns `heads * out_channels` features per node when `concat` is
        # set and `out_channels` features (mean over heads) otherwise.
        x = self.hetero_lin(x, node_type)

        type_emb = self.edge_type_emb(edge_type)
        type_emb = F.leaky_relu(type_emb, self.negative_slope)

        # propagate_type: (x: Tensor, edge_type_emb: Tensor, edge_attr: OptTensor)  # noqa
        out = self.propagate(edge_index,
                             x=x,
                             edge_type_emb=type_emb,
                             edge_attr=edge_attr,
                             size=size)

        if not self.concat:
            # Average the heads, then add the root features.
            out = out.mean(dim=1)
            if self.root_weight:
                out += x
            return out

        if self.root_weight:
            # Broadcast the root features over the head dimension.
            out += x.view(-1, 1, self.out_channels)
        return out.view(-1, self.heads * self.out_channels)

    def message(self, x_i: Tensor, x_j: Tensor, edge_type_emb: Tensor,
                edge_attr: Tensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # Embed the raw edge attributes.
        attr_emb = self.edge_attr_emb(edge_attr)
        attr_emb = F.leaky_relu(attr_emb, self.negative_slope)

        # Per-head attention logits, normalized with `softmax` over `index`
        # and followed by dropout.
        att_input = torch.cat([x_i, x_j, edge_type_emb, attr_emb], dim=-1)
        alpha = F.leaky_relu(self.att(att_input), self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # The same message content is shared by all heads and scaled by the
        # head-specific attention weight via broadcasting.
        content = self.lin(torch.cat([x_j, attr_emb], dim=-1))
        return content.unsqueeze(-2) * alpha.unsqueeze(-1)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (f'{cls_name}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
예제 #9
0
class Net(torch.nn.Module):
    """Graph classifier built from stacked ``AGGINConv`` layers.

    The first two columns of the node features are treated as categorical
    values and embedded; the remaining columns are used as-is. The outputs of
    all convolution layers are combined by a ``JumpingKnowledge`` module,
    pooled per graph with ``Set2Set`` and passed through three linear layers.

    Args:
        num_classes: Number of output logits per graph.
        gnn_layers: Number of ``AGGINConv`` layers.
        embed_dim: Embedding size of each categorical node feature.
        hidden_dim: Hidden size of the convolutions and of the head.
        jk_layer: Jumping-knowledge configuration. ``'cat'`` or ``'max'``
            select the mode of the same name; a string of digits (or a
            non-negative integer) selects the ``'lstm'`` mode and is forwarded
            to ``JumpingKnowledge`` as its layer count.
        process_step: Number of ``Set2Set`` processing steps.
        dropout: Stored on the instance as ``self.dropout``.
            NOTE(review): ``forward`` never applies it -- confirm whether
            dropout was meant to be used.

    Raises:
        ValueError: If ``jk_layer`` is none of the supported values.
    """
    def __init__(self, num_classes, gnn_layers, embed_dim, hidden_dim,
                 jk_layer, process_step, dropout):
        super(Net, self).__init__()

        # Validate the jumping-knowledge configuration up front. Previously an
        # unknown value left the model without `jk`/`s2s`/`fc*` and only failed
        # later with an AttributeError in `forward`/`reset_parameters`.
        if str(jk_layer).isdigit():
            jk_mode, lstm_layers = 'lstm', int(jk_layer)
        elif jk_layer in ('cat', 'max'):
            jk_mode, lstm_layers = jk_layer, None
        else:
            raise ValueError(
                f"Unsupported jk_layer {jk_layer!r}: expected a string of "
                f"digits, 'cat' or 'max'")

        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        # Both categorical columns share one table with 6 entries, so their
        # values must lie in [0, 6).
        self.embedding = Embedding(6, embed_dim)

        for i in range(gnn_layers):
            # The first layer consumes the two embedded categorical columns
            # plus two further input features (presumably continuous node
            # features -- confirm against the dataset).
            in_dim = 2 * embed_dim + 2 if i == 0 else hidden_dim
            self.convs.append(
                AGGINConv(Sequential(Linear(in_dim, hidden_dim), ReLU(),
                                     Linear(hidden_dim, hidden_dim),
                                     ReLU(), BN(hidden_dim)),
                          train_eps=True))

        if jk_mode == 'lstm':
            # NOTE(review): upstream PyG names this keyword `num_layers`;
            # `gnn_layers` is kept as in the original call -- confirm which
            # JumpingKnowledge implementation is imported here.
            self.jk = JumpingKnowledge(mode='lstm',
                                       channels=hidden_dim,
                                       gnn_layers=lstm_layers)
        else:
            self.jk = JumpingKnowledge(mode=jk_mode)

        # 'cat' concatenates the outputs of all layers; the other modes keep
        # the per-layer width.
        if jk_mode == 'cat':
            pooled_dim = gnn_layers * hidden_dim
        else:
            pooled_dim = hidden_dim
        half_dim = int(hidden_dim / 2)

        self.s2s = (Set2Set(pooled_dim, processing_steps=process_step))
        # Set2Set doubles the feature dimension of its input.
        self.fc1 = Linear(2 * pooled_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, half_dim)
        self.fc3 = Linear(half_dim, num_classes)

    def reset_parameters(self):
        """Re-initializes every learnable sub-module."""
        self.embedding.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jk.reset_parameters()
        self.s2s.reset_parameters()
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        self.fc3.reset_parameters()

    def forward(self, data):
        """Computes one row of class logits per graph in ``data``.

        ``data`` must provide ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Embedding the categorical values from Gene expression and Node type
        xc = x[:, :2].type(torch.long)
        ems = self.embedding(xc)
        # Flatten the two per-column embeddings into one vector per node.
        ems = ems.view(-1, ems.shape[1] * ems.shape[2])
        x = torch.cat((ems, x[:, 2:]), dim=1)

        # Keep every layer's output for the jumping-knowledge aggregation.
        xs = []
        for conv in self.convs:
            x = conv(x, edge_index)
            xs.append(x)
        x = self.jk(xs)
        x = self.s2s(x, batch)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)

        return logits
예제 #10
0
class Net(torch.nn.Module):
    """Graph-level model with an atom branch and an optional clique branch.

    Atoms are processed by ``GINEConv`` layers with per-layer bond encoders.
    When ``inter_message_passing`` is enabled, cliques are processed by
    ``GINConv`` layers over ``data.tree_edge_index`` and, in every layer, atom
    features are averaged into their cliques and clique features are averaged
    back into their atoms. Both branches are mean-pooled and summed before the
    final linear layer.

    Args:
        hidden_channels: Hidden feature size used throughout the model.
        out_channels: Size of the output.
        num_layers: Number of message passing layers.
        dropout: Dropout probability. (default: ``0.0``)
        inter_message_passing: Whether to use the clique branch.
            (default: ``True``)
    """
    def __init__(self,
                 hidden_channels,
                 out_channels,
                 num_layers,
                 dropout=0.0,
                 inter_message_passing=True):
        super(Net, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        self.inter_message_passing = inter_message_passing

        def make_mlp():
            # Two-layer MLP with a widened hidden layer, used inside each conv.
            return Sequential(
                Linear(hidden_channels, 2 * hidden_channels),
                BatchNorm1d(2 * hidden_channels),
                ReLU(),
                Linear(2 * hidden_channels, hidden_channels),
            )

        self.atom_encoder = AtomEncoder(hidden_channels)
        # Cliques are embedded from a table with 4 entries.
        self.clique_encoder = Embedding(4, hidden_channels)

        # Atom branch.
        self.bond_encoders = ModuleList()
        self.atom_convs = ModuleList()
        self.atom_batch_norms = ModuleList()

        for _ in range(num_layers):
            self.bond_encoders.append(BondEncoder(hidden_channels))
            self.atom_convs.append(GINEConv(make_mlp(), train_eps=True))
            self.atom_batch_norms.append(BatchNorm1d(hidden_channels))

        # Clique branch.
        self.clique_convs = ModuleList()
        self.clique_batch_norms = ModuleList()

        for _ in range(num_layers):
            self.clique_convs.append(GINConv(make_mlp(), train_eps=True))
            self.clique_batch_norms.append(BatchNorm1d(hidden_channels))

        # Per-layer transformations for the atom <-> clique exchange.
        self.atom2clique_lins = ModuleList()
        self.clique2atom_lins = ModuleList()

        for _ in range(num_layers):
            self.atom2clique_lins.append(
                Linear(hidden_channels, hidden_channels))
            self.clique2atom_lins.append(
                Linear(hidden_channels, hidden_channels))

        # Read-out head.
        self.atom_lin = Linear(hidden_channels, hidden_channels)
        self.clique_lin = Linear(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initializes every learnable sub-module."""
        self.atom_encoder.reset_parameters()
        self.clique_encoder.reset_parameters()

        atom_layers = zip(self.bond_encoders, self.atom_convs,
                          self.atom_batch_norms)
        for layer_modules in atom_layers:
            for module in layer_modules:
                module.reset_parameters()

        clique_layers = zip(self.clique_convs, self.clique_batch_norms)
        for layer_modules in clique_layers:
            for module in layer_modules:
                module.reset_parameters()

        exchange_layers = zip(self.atom2clique_lins, self.clique2atom_lins)
        for layer_modules in exchange_layers:
            for module in layer_modules:
                module.reset_parameters()

        for module in (self.atom_lin, self.clique_lin, self.lin):
            module.reset_parameters()

    def forward(self, data):
        """Computes the model output for the graphs in ``data``.

        ``data`` must provide ``x``, ``edge_index``, ``edge_attr`` and
        ``batch``; with ``inter_message_passing`` it must also provide
        ``x_clique``, ``atom2clique_index``, ``tree_edge_index`` and
        ``num_cliques``.
        """
        use_cliques = self.inter_message_passing

        x = self.atom_encoder(data.x.squeeze())

        if use_cliques:
            x_clique = self.clique_encoder(data.x_clique.squeeze())

        for i in range(self.num_layers):
            # Atom-level message passing.
            bond_emb = self.bond_encoders[i](data.edge_attr)
            x = self.atom_convs[i](x, data.edge_index, bond_emb)
            x = self.atom_batch_norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)

            if not use_cliques:
                continue

            row, col = data.atom2clique_index

            # Atoms -> cliques: average the member atoms of every clique.
            atom_mean = scatter(x[row],
                                col,
                                dim=0,
                                dim_size=x_clique.size(0),
                                reduce='mean')
            x_clique = x_clique + F.relu(self.atom2clique_lins[i](atom_mean))

            # Clique-level message passing.
            x_clique = self.clique_convs[i](x_clique, data.tree_edge_index)
            x_clique = self.clique_batch_norms[i](x_clique)
            x_clique = F.relu(x_clique)
            x_clique = F.dropout(x_clique,
                                 self.dropout,
                                 training=self.training)

            # Cliques -> atoms: average the cliques every atom belongs to.
            clique_mean = scatter(x_clique[col],
                                  row,
                                  dim=0,
                                  dim_size=x.size(0),
                                  reduce='mean')
            x = x + F.relu(self.clique2atom_lins[i](clique_mean))

        # Mean-pool the atoms according to `data.batch`.
        x = scatter(x, data.batch, dim=0, reduce='mean')
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.atom_lin(x)

        if use_cliques:
            # Expand the per-graph clique counts into a clique -> graph index.
            tree_batch = torch.repeat_interleave(data.num_cliques)
            x_clique = scatter(x_clique,
                               tree_batch,
                               dim=0,
                               dim_size=x.size(0),
                               reduce='mean')
            x_clique = F.dropout(x_clique,
                                 self.dropout,
                                 training=self.training)
            x_clique = self.clique_lin(x_clique)
            x = x + x_clique

        x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        return self.lin(x)
예제 #11
0
class LightGCN(torch.nn.Module):
    r"""The LightGCN model from the `"LightGCN: Simplifying and Powering
    Graph Convolution Network for Recommendation"
    <https://arxiv.org/abs/2002.02126>`_ paper.

    :class:`~torch_geometric.nn.models.LightGCN` learns embeddings by linearly
    propagating them on the underlying graph, and uses the weighted sum of the
    embeddings learned at all layers as the final embedding

    .. math::
        \textbf{x}_i = \sum_{l=0}^{L} \alpha_l \textbf{x}^{(l)}_i,

    where each layer's embedding is computed as

    .. math::
        \mathbf{x}^{(l+1)}_i = \sum_{j \in \mathcal{N}(i)}
        \frac{1}{\sqrt{\deg(i)\deg(j)}}\mathbf{x}^{(l)}_j.

    Two prediction heads and training objectives are provided:
    **link prediction** (via
    :meth:`~torch_geometric.nn.models.LightGCN.link_pred_loss` and
    :meth:`~torch_geometric.nn.models.LightGCN.predict_link`) and
    **recommendation** (via
    :meth:`~torch_geometric.nn.models.LightGCN.recommendation_loss` and
    :meth:`~torch_geometric.nn.models.LightGCN.recommend`).

    .. note::

        Embeddings are propagated according to the graph connectivity specified
        by :obj:`edge_index` while rankings or link probabilities are computed
        according to the edges specified by :obj:`edge_label_index`.

    Args:
        num_nodes (int): The number of nodes in the graph.
        embedding_dim (int): The dimensionality of node embeddings.
        num_layers (int): The number of
            :class:`~torch_geometric.nn.conv.LGConv` layers.
        alpha (float or Tensor, optional): The scalar or vector specifying the
            re-weighting coefficients for aggregating the final embedding.
            If set to :obj:`None`, the uniform initialization of
            :obj:`1 / (num_layers + 1)` is used. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of the underlying
            :class:`~torch_geometric.nn.conv.LGConv` layers.
    """
    def __init__(
        self,
        num_nodes: int,
        embedding_dim: int,
        num_layers: int,
        alpha: Optional[Union[float, Tensor]] = None,
        **kwargs,
    ):
        super().__init__()

        self.num_nodes = num_nodes
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        # One coefficient for the input embedding plus one per layer.
        num_terms = num_layers + 1
        if isinstance(alpha, Tensor):
            assert alpha.size(0) == num_terms
        else:
            scalar = 1. / num_terms if alpha is None else alpha
            alpha = torch.tensor([scalar] * num_terms)
        self.register_buffer('alpha', alpha)

        self.embedding = Embedding(num_nodes, embedding_dim)
        self.convs = ModuleList([LGConv(**kwargs) for _ in range(num_layers)])

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes the embedding table and all convolution layers."""
        self.embedding.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()

    def get_embedding(self, edge_index: Adj) -> Tensor:
        r"""Returns the final node embeddings, *i.e.* the
        :obj:`alpha`-weighted sum of the input embedding and the output of
        every propagation layer.

        Args:
            edge_index (Tensor or SparseTensor): Edge tensor specifying the
                connectivity of the graph.
        """
        layer_emb = self.embedding.weight
        combined = layer_emb * self.alpha[0]

        for i in range(self.num_layers):
            layer_emb = self.convs[i](layer_emb, edge_index)
            combined = combined + layer_emb * self.alpha[i + 1]

        return combined

    def forward(self,
                edge_index: Adj,
                edge_label_index: OptTensor = None) -> Tensor:
        r"""Computes rankings for pairs of nodes as the dot product of their
        final embeddings.

        Args:
            edge_index (Tensor or SparseTensor): Edge tensor specifying the
                connectivity of the graph.
            edge_label_index (Tensor, optional): Edge tensor specifying the
                node pairs for which to compute rankings or probabilities.
                If :obj:`edge_label_index` is set to :obj:`None`, all edges in
                :obj:`edge_index` will be used instead. (default: :obj:`None`)
        """
        pairs = edge_index if edge_label_index is None else edge_label_index

        emb = self.get_embedding(edge_index)

        src_emb = emb[pairs[0]]
        dst_emb = emb[pairs[1]]
        return (src_emb * dst_emb).sum(dim=-1)

    def predict_link(self,
                     edge_index: Adj,
                     edge_label_index: OptTensor = None,
                     prob: bool = False) -> Tensor:
        r"""Predict links between nodes specified in :obj:`edge_label_index`.

        Args:
            prob (bool): Whether probabilities should be returned. If
                :obj:`False`, the probabilities are rounded. (default:
                :obj:`False`)
        """
        probabilities = self(edge_index, edge_label_index).sigmoid()
        if prob:
            return probabilities
        return probabilities.round()

    def recommend(self,
                  edge_index: Adj,
                  src_index: OptTensor = None,
                  dst_index: OptTensor = None,
                  k: int = 1) -> Tensor:
        r"""Get top-:math:`k` recommendations for nodes in :obj:`src_index`.

        Args:
            src_index (Tensor, optional): Node indices for which
                recommendations should be generated.
                If set to :obj:`None`, all nodes will be used.
                (default: :obj:`None`)
            dst_index (Tensor, optional): Node indices which represent the
                possible recommendation choices.
                If set to :obj:`None`, all nodes will be used.
                (default: :obj:`None`)
            k (int, optional): Number of recommendations. (default: :obj:`1`)
        """
        emb = self.get_embedding(edge_index)

        src_emb = emb if src_index is None else emb[src_index]
        dst_emb = emb if dst_index is None else emb[dst_index]

        # Score every source against every candidate destination.
        scores = src_emb @ dst_emb.t()
        top_index = scores.topk(k, dim=-1).indices

        if dst_index is None:
            return top_index

        # Map local top-indices to original indices.
        return dst_index[top_index.view(-1)].view(*top_index.size())

    def link_pred_loss(self, pred: Tensor, edge_label: Tensor,
                       **kwargs) -> Tensor:
        r"""Computes the model loss for a link prediction objective via the
        :class:`torch.nn.BCEWithLogitsLoss`.

        Args:
            pred (Tensor): The predictions.
            edge_label (Tensor): The ground-truth edge labels. They are cast
                to the data type of :obj:`pred`.
            **kwargs (optional): Additional arguments of the underlying
                :class:`torch.nn.BCEWithLogitsLoss` loss function.
        """
        criterion = torch.nn.BCEWithLogitsLoss(**kwargs)
        target = edge_label.to(pred.dtype)
        return criterion(pred, target)

    def recommendation_loss(self,
                            pos_edge_rank: Tensor,
                            neg_edge_rank: Tensor,
                            lambda_reg: float = 1e-4,
                            **kwargs) -> Tensor:
        r"""Computes the model loss for a ranking objective via the Bayesian
        Personalized Ranking (BPR) loss.

        .. note::

            The i-th entry in the :obj:`pos_edge_rank` vector and i-th entry
            in the :obj:`neg_edge_rank` entry must correspond to ranks of
            positive and negative edges of the same entity (*e.g.*, user).

        Args:
            pos_edge_rank (Tensor): Positive edge rankings.
            neg_edge_rank (Tensor): Negative edge rankings.
            lambda_reg (float, optional): The :math:`L_2` regularization
                strength of the Bayesian Personalized Ranking (BPR) loss.
                (default: 1e-4)
            **kwargs (optional): Additional arguments of the underlying
                :class:`torch_geometric.nn.models.lightgcn.BPRLoss` loss
                function.
        """
        criterion = BPRLoss(lambda_reg, **kwargs)
        # The full embedding table is handed to the loss as third argument.
        parameters = self.embedding.weight
        return criterion(pos_edge_rank, neg_edge_rank, parameters)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (f'{cls_name}({self.num_nodes}, '
                f'{self.embedding_dim}, num_layers={self.num_layers})')
예제 #12
0
class SchNet(torch.nn.Module):
    r"""Re-implementation of SchNet from the `"SchNet: A Continuous-filter
    Convolutional Neural Network for Modeling Quantum Interactions"
    <https://arxiv.org/abs/1706.08566>`_ paper, expressed in the 3DGN
    framework of the `"Spherical Message Passing for 3D Graph Networks"
    <https://arxiv.org/abs/2102.05013v2>`_ paper.

    Args:
        energy_and_force (bool, optional): If set to :obj:`True`,
            :meth:`forward` enables gradient tracking on the atomic positions
            so the output can be differentiated with respect to them (the
            force computation itself is not part of this class).
            (default: :obj:`False`)
        cutoff (float, optional): Cutoff distance for interatomic
            interactions. (default: :obj:`10.0`)
        num_layers (int, optional): The number of layers. (default: :obj:`6`)
        hidden_channels (int, optional): Hidden embedding size.
            (default: :obj:`128`)
        num_filters (int, optional): The number of filters to use.
            (default: :obj:`128`)
        num_gaussians (int, optional): The number of gaussians :math:`\mu`.
            (default: :obj:`50`)
    """
    def __init__(self, energy_and_force=False, cutoff=10.0, num_layers=6,
                 hidden_channels=128, num_filters=128, num_gaussians=50):
        super().__init__()

        # Keep the hyper-parameters around for inspection.
        self.energy_and_force = energy_and_force
        self.cutoff = cutoff
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_gaussians = num_gaussians

        # Initial node embedding (100 rows) and distance expansion.
        self.init_v = Embedding(100, hidden_channels)
        self.dist_emb = emb(0.0, cutoff, num_gaussians)

        # One node-update and one edge-update module per layer. The node
        # updates are created (and registered) before the edge updates.
        self.update_vs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.update_vs.append(update_v(hidden_channels, num_filters))

        self.update_es = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.update_es.append(
                update_e(hidden_channels, num_filters, num_gaussians, cutoff))

        self.update_u = update_u(hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding and every update module."""
        self.init_v.reset_parameters()
        for edge_layer in self.update_es:
            edge_layer.reset_parameters()
        for node_layer in self.update_vs:
            node_layer.reset_parameters()
        self.update_u.reset_parameters()

    def forward(self, batch_data):
        """Run the network on a batch exposing ``z``, ``pos`` and ``batch``.

        A radius graph is built from the positions with ``self.cutoff``; the
        per-edge distances and their expansion feed the alternating edge and
        node updates, and the result of ``update_u`` is returned.
        """
        z = batch_data.z
        pos = batch_data.pos
        batch = batch_data.batch
        if self.energy_and_force:
            # In-place: the caller's position tensor starts tracking gradients.
            pos.requires_grad_()

        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        src, dst = edge_index
        dist = (pos[src] - pos[dst]).norm(dim=-1)
        dist_emb = self.dist_emb(dist)

        v = self.init_v(z)

        for edge_layer, node_layer in zip(self.update_es, self.update_vs):
            e = edge_layer(v, dist, dist_emb, edge_index)
            v = node_layer(v, e, edge_index)

        return self.update_u(v, batch)
예제 #13
0
class BaseCrystalModel(Module):
    """
    Base model for crystal problem.
    """
    def __init__(self,
                 num_node_features=1,
                 num_edge_features=3,
                 num_state_features=0,
                 num_node_hidden_channels=128,
                 num_node_interaction_channels=128,
                 num_interactions=1,
                 num_edge_gaussians=None,
                 num_node_embeddings=120,
                 cutoff=10.0,
                 out_size=1,
                 readout='add',
                 mean=None,
                 std=None,
                 norm=False,
                 atom_ref=None,
                 simple_z=True,
                 interactions=None,
                 readout_layer=None,
                 add_state=False,
                 **kwargs):
        """
        Set up the node embeddings, the interaction layer and the readout layer.

        Args:
            num_node_features: (int) input number of node (atom) features.
            num_edge_features: (int) input number of bond features. Only used
                as the fallback for ``num_edge_gaussians`` when that is None.
            num_state_features: (int) input number of state features.
            num_node_hidden_channels: (int) hidden channels for node features.
            num_node_interaction_channels: (int) channels for node features in
                the interaction/readout part.
            num_interactions: (int) number of interaction (conv) layers.
            num_edge_gaussians: (int) number of Gaussian smearing terms for the
                radius (deprecated); keep it compatible with your bond data.
            num_node_embeddings: (int) number of rows of the initial embedding
                matrix that stands in for the node features.
            cutoff: (float) cutoff for calculating neighbour bonds.
            out_size: (int) output size; 1 for regression, to be defined for
                classification.
            readout: (str) node merge method, one of "add", "sum", "min",
                "mean", "max".
            mean: (float) mean used to rescale the output.
            std: (float) std used to rescale the output.
            norm: (bool) whether the output is replaced by its norm.
            atom_ref: (torch.Tensor of shape (120, 1)) per-atom properties.
                E.g. if the target is the compound volume, ``atom_ref`` could be
                the atomic volumes (H, H, He, Li, ...); the first entry can be
                duplicated so that the index of ``H`` starts from 1.
            simple_z: (bool, str) ``True`` embeds only ``z``; ``"no_embed"``
                uses only linear layers on ``x``; any other value uses both.
            interactions: (Callable) torch module for the interactions.
                Dynamic: pass the module here. Static: override
                ``get_interactions_layer`` and leave this None. Its forward
                input is (h, edge_index, edge_weight, edge_attr, data=data).
            readout_layer: (Callable) torch module for the readout.
                Dynamic: pass the module here. Static: override
                ``get_readout_layer`` and leave this None. Its forward input is
                (out, batch).
            add_state: (bool) add the state attribute before the output.
            **kwargs: keys containing ``interaction_kwargs_`` or
                ``readout_kwargs_`` are collected, with that marker removed,
                into ``self.interaction_kwargs`` / ``self.readout_kwargs``.

        Raises:
            NotImplementedError: if ``interactions`` is neither None nor a
                Module, or ``readout_layer`` is neither None nor callable.
        """
        super(BaseCrystalModel, self).__init__()

        self.interaction_kwargs = {
            k.replace("interaction_kwargs_", ""): v
            for k, v in kwargs.items() if "interaction_kwargs_" in k
        }

        self.readout_kwargs = {
            k.replace("readout_kwargs_", ""): v
            for k, v in kwargs.items() if "readout_kwargs_" in k
        }

        # Initial definitions.
        if num_edge_gaussians is None:
            num_edge_gaussians = num_edge_features

        assert readout in ['add', 'sum', 'min', 'mean', "max"]

        self.num_node_hidden_channels = num_node_hidden_channels
        self.num_state_features = num_state_features
        self.num_node_interaction_channels = num_node_interaction_channels
        self.num_interactions = num_interactions
        self.num_edge_gaussians = num_edge_gaussians
        self.cutoff = cutoff
        self.readout = readout

        self.mean = mean
        self.std = std
        self.scale = None
        self.simple_z = simple_z
        self.interactions = interactions
        self.readout_layer = readout_layer
        self.out_size = out_size
        self.norm = norm

        # Atomic properties kept as buffers for later use.
        # (Do not embed too many of them: it is slow and rarely needed.)
        atomic_mass = torch.from_numpy(ase_data.atomic_masses)  # atomic masses
        covalent_radii = torch.from_numpy(ase_data.covalent_radii)  # covalent radii
        # Buffers must be registered to take effect: a tensor that is merely
        # assigned as a Module attribute is not turned into a buffer and is
        # therefore invisible to state_dict(), buffers() and named_buffers().
        self.register_buffer('atomic_mass', atomic_mass)
        self.register_buffer('atomic_radii', covalent_radii)

        # Define the inputs: use atom properties, or an Embedding that starts
        # from random data; the same choice applies to bond properties.
        if num_node_embeddings < 120:
            print(
                "default, num_node_embeddings>=120,if you want simple the net work and "
                "This network does not apply to other elements, the num_node_embeddings could be less but large than "
                "the element type number in your data.")

        # Number of atom types: normally left unchanged, it is the count of all
        # element types. Networks relying on an embedding usually generalise
        # poorly to elements that are absent from the training set.

        if simple_z is True:
            if num_node_features != 0:
                warnings.warn(
                    "simple_z just accept num_node_features == 0, "
                    "and don't use your self-defined 'x' data, but element number Z",
                    UserWarning)

            self.embedding_e = Embedding(num_node_embeddings,
                                         num_node_hidden_channels)
        elif self.simple_z == "no_embed":
            self.embedding_l = Linear(num_node_features,
                                      num_node_hidden_channels)
            self.embedding_l2 = Linear(num_node_hidden_channels,
                                       num_node_hidden_channels)
        else:
            assert num_node_features > 0, "The `num_node_features` must be the same size with `x` feature."
            self.embedding_e = Embedding(num_node_embeddings,
                                         num_node_hidden_channels)
            self.embedding_l = Linear(num_node_features,
                                      num_node_hidden_channels)
            self.embedding_l2 = Linear(num_node_hidden_channels,
                                       num_node_hidden_channels)

        self.bn = BatchNorm1d(num_node_hidden_channels)

        # Interaction layer: subclasses customise get_interactions_layer.
        if interactions is None:
            self.get_interactions_layer()
        elif isinstance(interactions, ModuleList):
            self.get_res_interactions_layer(interactions)
        elif isinstance(interactions, Module):
            self.interactions = interactions
        else:
            raise NotImplementedError(
                "please implement get_interactions_layer function, "
                "or pass interactions parameters.")

        # Readout (merge) layer: subclasses customise get_readout_layer.
        # The check used to inspect ``interactions`` instead of
        # ``readout_layer``, which rejected a user-supplied readout layer
        # whenever ``interactions`` was None. Any callable is accepted so that
        # non-Module callables that previously slipped through keep working.
        if readout_layer is None:
            self.get_readout_layer()
        elif callable(readout_layer):
            self.readout_layer = readout_layer
        else:
            raise NotImplementedError(
                "please implement get_readout_layer function, "
                "or pass readout_layer parameters.")

        # NOTE(review): register_buffer only accepts a Tensor or None, so a
        # string such as "embed" (or any other non-tensor) appears to raise
        # here before the branches below are reached - confirm the intent.
        self.register_buffer('initial_atom_ref', atom_ref)

        if atom_ref is None:
            self.atom_ref = atom_ref
        elif isinstance(atom_ref, Tensor) and atom_ref.shape[0] == 120:
            # NOTE(review): the closure indexes the tensor passed in, not the
            # registered buffer - verify device handling after ``.to()``.
            self.atom_ref = lambda x: atom_ref[x]
        elif atom_ref == "embed":
            self.atom_ref = Embedding(120, 1)
            # NOTE(review): in this branch ``atom_ref`` is the string "embed",
            # so this copy looks like it cannot succeed - TODO confirm.
            self.atom_ref.weight.data.copy_(atom_ref)
            # Whether the per-atom property is added at the end.
        else:
            self.atom_ref = atom_ref

        self.add_state = add_state

        if self.add_state:
            self.dp = LayerNorm(self.num_state_features)
            self.ads = Linear(self.num_state_features, 1)
            self.ads2 = Linear(1, self.out_size)

        self.reset_parameters()

    def forward(self, data):
        # 使用embedding 作为假的原子特征输入,而没有用原子特征输入
        assert hasattr(data, "z")
        assert hasattr(data, "pos")
        assert hasattr(data, "batch")
        z = data.z
        batch = data.batch
        # pos = data.pos
        # 处理数据阶段
        if self.simple_z is True:
            # 处理数据阶段
            assert z.dim() == 1 and z.dtype == torch.long

            h = self.embedding_e(z)
            h = F.softplus(h)
        elif self.simple_z == "no_embed":
            assert hasattr(data, "x")
            x = data.x
            h = F.softplus(self.embedding_l(x))
            h = self.embedding_l2(h)
        else:
            assert hasattr(data, "x")
            x = data.x
            h1 = self.embedding_e(z)
            x = F.softplus(self.embedding_l(x))
            h2 = self.embedding_l2(x)
            h = h1 + h2

        data.x = h

        if data.edge_index is None:
            edge_index = data.adj_t
        else:
            edge_index = data.edge_index

        if not hasattr(data, "edge_weight") or data.edge_weight is None:
            if not hasattr(data, "edge_attr") or data.edge_attr is None:
                raise NotImplementedError(
                    "Must offer edge_weight or edge_attr.")
            else:
                if data.edge_attr.shape[1] == 1:
                    data.edge_weight = data.edge_attr.reshape(-1)
                else:
                    data.edge_weight = torch.norm(data.edge_attr,
                                                  dim=1,
                                                  keepdim=True)

        h = self.bn(h)

        h = self.interactions(h,
                              edge_index,
                              data.edge_weight,
                              data.edge_attr,
                              data=data)

        h = self.internal_forward(h, z)

        out = self.readout_layer(h, batch)

        if self.add_state:
            assert hasattr(
                data, "state_attr"
            ), "the ``add_state`` must accept ``state_attr`` in data."
            sta = self.dp(data.state_attr)
            sta = self.ads(sta)
            sta = F.relu(sta)
            out = self.ads2(out.expand_as(sta) + sta)

        return self.output_forward(out)

    def get_res_interactions_layer(self, interactions):
        """Wrap ``interactions`` in a ``GeoResNet`` and store the result as
        ``self.interactions``."""
        wrapped = GeoResNet(interactions)
        self.interactions = wrapped

    @abstractmethod
    def get_interactions_layer(self):
        """This part shloud re-defined. And must add the ``interactions`` attribute.

        Examples::

            >>> ...
            >>> self.interactions = YourNet()
        """

    def get_readout_layer(self):
        """Build the default readout layer and store it as ``self.readout_layer``.

        This method may be re-defined; an override must set the
        ``readout_layer`` attribute.

        ``self.readout_kwargs`` is filled in ``__init__`` from constructor
        keywords with the ``readout_kwargs_`` marker removed, so a layer-size
        specification passed as ``readout_kwargs_layers_size`` arrives here
        under the key ``layers_size``. When it is present the readout is built
        from ``self.readout_kwargs`` alone; otherwise a default size list made
        of ``num_node_interaction_channels``, the readout mode, 200 and
        ``out_size`` is passed positionally.

        Examples::

            >>> self.readout_layer = torch.nn.Sequential(...)

        Examples::

            >>> ...
            >>> self.readout_layer = YourNet()
        """
        readout_kwargs = self.readout_kwargs
        # The old check looked only for the un-stripped key, which never occurs
        # in ``self.readout_kwargs`` as built by ``__init__``. It is still
        # honoured in case the dict was filled by hand.
        has_custom_sizes = ("layers_size" in readout_kwargs
                            or "readout_kwargs_layers_size" in readout_kwargs)
        if has_custom_sizes:
            self.readout_layer = GeneralReadOutLayer(**readout_kwargs)
        else:
            default_sizes = [
                self.num_node_interaction_channels, self.readout, 200,
                self.out_size
            ]
            self.readout_layer = GeneralReadOutLayer(default_sizes,
                                                     **readout_kwargs)

    def output_forward(self, out):

        if self.mean is not None and self.std is not None:
            out = out * self.std + self.mean

        if self.norm is True:
            out = torch.norm(out, dim=-1, keepdim=True)

        if self.scale is not None:
            out = self.scale * out

        return out.view(-1)

    def dipole_forward(self, h, z, pos, batch):
        """Weight ``h`` by each position's offset from a mass-weighted centre
        (dipole-moment style).

        The masses come from the ``atomic_mass`` buffer indexed by ``z``; the
        centre is the segment-wise sum of ``mass * pos`` divided by the
        segment-wise sum of ``mass``, then indexed by ``batch``.
        """
        mass = self.atomic_mass[z].view(-1, 1)
        # Centre of mass per segment.
        weighted_pos_sum = segment_csr(mass * pos, get_ptr(batch))
        total_mass = segment_csr(mass, get_ptr(batch))
        centre = weighted_pos_sum / total_mass
        return h * (pos - centre[batch])

    def internal_forward(self, h, z):
        if self.atom_ref is not None:
            h = h + self.atom_ref(z)

        return h

    def __repr__(self):
        return (
            f'{self.__class__.__name__}('
            f'num_node_hidden_channels={self.num_node_hidden_channels}, '
            f'num_node_interaction_channels={self.num_node_interaction_channels}, '
            f'num_interactions={self.num_interactions}, '
            f'num_edge_gaussians={self.num_edge_gaussians}, '
            f'cutoff={self.cutoff})')

    def reset_parameters(self):
        if hasattr(self, "embedding_e"):
            self.embedding_e.reset_parameters()
        if hasattr(self, "embedding_l"):
            self.embedding_l.reset_parameters()
        if hasattr(self, "embedding_l2"):
            self.embedding_l2.reset_parameters()
        self.bn.reset_parameters()
        if self.atom_ref is not None:
            self.atom_ref.weight.data.copy_(self.initial_atomref)