Example #1
class TGNMemory(torch.nn.Module):
    r"""The Temporal Graph Network (TGN) memory model from the
    `"Temporal Graph Networks for Deep Learning on Dynamic Graphs"
    <https://arxiv.org/abs/2006.10637>`_ paper.

    .. note::

        For an example of using TGN, see `examples/tgn.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        tgn.py>`_.

    Args:
        num_nodes (int): The number of nodes to save memories for.
        raw_msg_dim (int): The raw message dimensionality.
        memory_dim (int): The hidden memory dimensionality.
        time_dim (int): The time encoding dimensionality.
        message_module (torch.nn.Module): The message function which
            combines source and destination node memory embeddings, the raw
            message and the time encoding.
        aggregator_module (torch.nn.Module): The message aggregator function
            which aggregates messages to the same destination into a single
            representation.
    """
    def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,
                 time_dim: int, message_module: Callable,
                 aggregator_module: Callable):
        super().__init__()

        self.num_nodes = num_nodes
        self.raw_msg_dim = raw_msg_dim
        self.memory_dim = memory_dim
        self.time_dim = time_dim

        self.msg_s_module = message_module
        self.msg_d_module = copy.deepcopy(message_module)
        self.aggr_module = aggregator_module
        self.time_enc = TimeEncoder(time_dim)
        self.gru = GRUCell(message_module.out_channels, memory_dim)

        self.register_buffer('memory', torch.empty(num_nodes, memory_dim))
        last_update = torch.empty(self.num_nodes, dtype=torch.long)
        self.register_buffer('last_update', last_update)
        self.register_buffer('__assoc__',
                             torch.empty(num_nodes, dtype=torch.long))

        self.msg_s_store = {}
        self.msg_d_store = {}

        self.reset_parameters()

    def reset_parameters(self):
        if hasattr(self.msg_s_module, 'reset_parameters'):
            self.msg_s_module.reset_parameters()
        if hasattr(self.msg_d_module, 'reset_parameters'):
            self.msg_d_module.reset_parameters()
        if hasattr(self.aggr_module, 'reset_parameters'):
            self.aggr_module.reset_parameters()
        self.time_enc.reset_parameters()
        self.gru.reset_parameters()
        self.reset_state()

    def reset_state(self):
        """Resets the memory to its initial state."""
        zeros(self.memory)
        zeros(self.last_update)
        self.__reset_message_store__()

    def detach(self):
        """Detachs the memory from gradient computation."""
        self.memory.detach_()

    def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
        """Returns, for all nodes :obj:`n_id`, their current memory and their
        last updated timestamp."""
        if self.training:
            memory, last_update = self.__get_updated_memory__(n_id)
        else:
            memory, last_update = self.memory[n_id], self.last_update[n_id]

        return memory, last_update

    def update_state(self, src, dst, t, raw_msg):
        """Updates the memory with newly encountered interactions
        :obj:`(src, dst, t, raw_msg)`."""
        n_id = torch.cat([src, dst]).unique()

        if self.training:
            self.__update_memory__(n_id)
            self.__update_msg_store__(src, dst, t, raw_msg, self.msg_s_store)
            self.__update_msg_store__(dst, src, t, raw_msg, self.msg_d_store)
        else:
            self.__update_msg_store__(src, dst, t, raw_msg, self.msg_s_store)
            self.__update_msg_store__(dst, src, t, raw_msg, self.msg_d_store)
            self.__update_memory__(n_id)

    def __reset_message_store__(self):
        i = self.memory.new_empty((0, ), dtype=torch.long)
        msg = self.memory.new_empty((0, self.raw_msg_dim))
        # Message store format: (src, dst, t, msg)
        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}
        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}

    def __update_memory__(self, n_id):
        memory, last_update = self.__get_updated_memory__(n_id)
        self.memory[n_id] = memory
        self.last_update[n_id] = last_update

    def __get_updated_memory__(self, n_id):
        self.__assoc__[n_id] = torch.arange(n_id.size(0), device=n_id.device)

        # Compute messages (src -> dst).
        msg_s, t_s, src_s, dst_s = self.__compute_msg__(
            n_id, self.msg_s_store, self.msg_s_module)

        # Compute messages (dst -> src).
        msg_d, t_d, src_d, dst_d = self.__compute_msg__(
            n_id, self.msg_d_store, self.msg_d_module)

        # Aggregate messages.
        idx = torch.cat([src_s, src_d], dim=0)
        msg = torch.cat([msg_s, msg_d], dim=0)
        t = torch.cat([t_s, t_d], dim=0)
        aggr = self.aggr_module(msg, self.__assoc__[idx], t, n_id.size(0))

        # Get local copy of updated memory.
        memory = self.gru(aggr, self.memory[n_id])

        # Get local copy of updated `last_update`.
        dim_size = self.last_update.size(0)
        last_update = scatter_max(t, idx, dim=0, dim_size=dim_size)[0][n_id]

        return memory, last_update

    def __update_msg_store__(self, src, dst, t, raw_msg, msg_store):
        n_id, perm = src.sort()
        n_id, count = n_id.unique_consecutive(return_counts=True)
        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):
            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])

    def __compute_msg__(self, n_id, msg_store, msg_module):
        data = [msg_store[i] for i in n_id.tolist()]
        src, dst, t, raw_msg = list(zip(*data))
        src = torch.cat(src, dim=0)
        dst = torch.cat(dst, dim=0)
        t = torch.cat(t, dim=0)
        raw_msg = torch.cat(raw_msg, dim=0)
        t_rel = t - self.last_update[src]
        t_enc = self.time_enc(t_rel.to(raw_msg.dtype))

        msg = msg_module(self.memory[src], self.memory[dst], raw_msg, t_enc)

        return msg, t, src, dst

    def train(self, mode: bool = True):
        """Sets the module in training mode."""
        if self.training and not mode:
            # Flush message store to memory in case we just entered eval mode.
            self.__update_memory__(
                torch.arange(self.num_nodes, device=self.memory.device))
            self.__reset_message_store__()
        super().train(mode)
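A minimal usage sketch for the memory module above (it is the TGNMemory class shipped with PyTorch Geometric, so it is imported from the library here). The dimensions, the synthetic events and the choice of IdentityMessage/LastAggregator are illustrative assumptions, not part of the original listing.

# Usage sketch: read memory, push a toy batch of events, detach between batches.
import torch
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator

num_nodes, raw_msg_dim, memory_dim, time_dim = 100, 16, 32, 10
memory = TGNMemory(
    num_nodes, raw_msg_dim, memory_dim, time_dim,
    message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
    aggregator_module=LastAggregator(),
)

# Toy interactions (src, dst, t, raw_msg).
src = torch.randint(0, num_nodes, (20, ))
dst = torch.randint(0, num_nodes, (20, ))
t = torch.arange(20)
raw_msg = torch.randn(20, raw_msg_dim)

n_id = torch.cat([src, dst]).unique()
z, last_update = memory(n_id)              # current memory of the involved nodes
memory.update_state(src, dst, t, raw_msg)  # store the new interactions
memory.detach()                            # stop gradients flowing across batches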
Example #2
class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        out_channels (int): Size of each output sample.
        edge_dim (int): Edge feature dimensionality.
        num_layers (int): Number of GNN layers.
        num_timesteps (int): Number of iterative refinement steps for global
            readout.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
    """
    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 out_channels: int,
                 edge_dim: int,
                 num_layers: int,
                 num_timesteps: int,
                 dropout: float = 0.0):
        super().__init__()

        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)

        conv = GATEConv(hidden_channels, hidden_channels, edge_dim, dropout)
        gru = GRUCell(hidden_channels, hidden_channels)
        self.atom_convs = torch.nn.ModuleList([conv])
        self.atom_grus = torch.nn.ModuleList([gru])
        for _ in range(num_layers - 1):
            conv = GATConv(hidden_channels,
                           hidden_channels,
                           dropout=dropout,
                           add_self_loops=False,
                           negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))

        self.mol_conv = GATConv(hidden_channels,
                                hidden_channels,
                                dropout=dropout,
                                add_self_loops=False,
                                negative_slope=0.01)
        self.mol_gru = GRUCell(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin1.reset_parameters()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, edge_index, edge_attr, batch):
        """"""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)
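A short usage sketch for the model above on a single toy molecular graph. The feature sizes and the random inputs are placeholder assumptions; the class (together with its GATEConv companion layer) is the AttentiveFP model shipped with PyTorch Geometric, so it can equally be imported via from torch_geometric.nn import AttentiveFP.

# Usage sketch with random inputs (all shapes are illustrative).
import torch

model = AttentiveFP(in_channels=39, hidden_channels=200, out_channels=1,
                    edge_dim=10, num_layers=2, num_timesteps=2, dropout=0.2)

num_atoms, num_bonds = 12, 24
x = torch.randn(num_atoms, 39)                    # atom features
edge_index = torch.randint(0, num_atoms, (2, num_bonds))
edge_attr = torch.randn(num_bonds, 10)            # bond features
batch = torch.zeros(num_atoms, dtype=torch.long)  # all atoms belong to one molecule

out = model(x, edge_index, edge_attr, batch)      # shape: [1, out_channels]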
Example #3
class PlainRNN(BaseModel):
	"""
	single_loss: single_loss= True, means the loss is
				  only caculated in the last time step.
	"""

	def __init__(self, inp_size,
					   hid_size,
					   out_size,
					   rnn_type="raw_rnn",
					   single_loss=True):
		super().__init__()
		allow_rnn_types = ["raw_rnn","lstm","gru"]
		assert rnn_type in allow_rnn_types
		self.rnn_type = rnn_type
		self.single_loss = single_loss
		self.inp_size = inp_size
		self.hid_size = hid_size
		self.out_size = out_size

		if rnn_type == "raw_rnn":
			self.lstm = RNNCell(inp_size, hid_size)
		if rnn_type == "lstm":
			self.lstm = LSTMCell(inp_size, hid_size)
		if rnn_type == "gru":
			self.lstm = GRUCell(inp_size, hid_size)
		self.fc1 = nn.Linear(hid_size, out_size)
		self.criterion = nn.CrossEntropyLoss()

	def init_weights(self):
		self.lstm.reset_parameters()
		self.fc1.reset_parameters()

	def init_states(self,batch_size):
		if self.rnn_type == "lstm":
			self.h = torch.zeros(batch_size, self.hid_size).to(device)
			self.c = torch.zeros(batch_size, self.hid_size).to(device)

		if self.rnn_type == "raw_rnn" or self.rnn_type == "gru":
			self.h = torch.zeros(batch_size, self.hid_size).to(device)


	def forward_train(self, x):
		assert len(x) == 2
		inp_x, inp_y = x
		inp_x = inp_x.to(device)
		inp_y = inp_y.to(device)
		batch, T, _ = inp_x.shape
		self.init_states(batch)

		dm_states = []
		loss = 0

		rr = torch.zeros((batch,self.out_size)).to(device)
		if self.rnn_type == "lstm":
			for i in range(T):
				y, (self.h,self.c) = self.lstm(inp_x[:,i],(self.h,self.c))
				output = self.fc1(y)
				rr.copy_(output)
				dm_states.append(rr.cpu().detach().numpy())

				if not self.single_loss:
					loss += self.criterion(output, inp_y.reshape(-1))
			if self.single_loss:
				loss += self.criterion(output, inp_y.reshape(-1))

		if self.rnn_type == "raw_rnn" or self.rnn_type == "gru":
			for i in range(T):
				self.h = self.lstm(inp_x[:,i],self.h)
				output = self.fc1(self.h)

				rr.copy_(output)
				dm_states.append(rr.cpu().detach().numpy())

				if not self.single_loss:
					loss += self.criterion(output, inp_y.reshape(-1))
			if self.single_loss:
				loss += self.criterion(output, inp_y.reshape(-1))

		print("loss is ",loss.cpu().item())
		outputs = dict(
					 loss = loss,
					 outputs = dm_states
					  )
		return outputs


	def forward_test(self,x):
		# x, new_state = self.lstm(x, state)
		# x = self.fc1(x
		assert len(x) == 2
		inp_x, inp_y = x
		inp_x = inp_x.to(device)
		batch, T, _ = inp_x.shape
		self.init_states(batch)

		dm_states = []

		rr = torch.zeros((batch,self.out_size)).to(device)
		if self.rnn_type == "lstm":
			for i in range(T):
				y, (self.h,self.c) = self.lstm(inp_x[:,i],(self.h,self.c))
				output = self.fc1(y)

				rr.copy_(output)
				dm_states.append(rr.cpu().detach().numpy())

		if self.rnn_type == "raw_rnn" or self.rnn_type == "gru":
			for i in range(T):
				self.h = self.lstm(inp_x[:,i],self.h)
				output = self.fc1(self.h)

				rr.copy_(output)
				dm_states.append(rr.cpu().detach().numpy())

		if self.single_loss:
			dm_states = (np.array(dm_states)[-1]).reshape(1,batch,-1)
		else:
			# dm_states = (np.array(dm_states)).reshape(T,batch,-1)
			dm_states = np.array(dm_states).reshape(T*batch,-1)
			dm_states_ = np.zeros_like(dm_states)
			index = np.argmax(dm_states,axis=1)
			dm_states_[range(T*batch),index] = 1
			dm_states = dm_states_.reshape(T,batch,-1).mean(axis=0).reshape(1,batch,-1)


		inp_y = inp_y.view(-1).cpu().numpy()

		outputs = dict(
					 outputs = dm_states,
					 labels = inp_y
					  )
		return outputs
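PlainRNN relies on project-specific pieces not shown here (BaseModel, device, the imported cell classes), so below is a self-contained sketch of the same per-time-step pattern using torch.nn.GRUCell with made-up sizes. It mirrors how forward_train walks the sequence and, with single_loss=True, computes the loss only at the final step.

# Standalone sketch of the recurrent loop (plain PyTorch, illustrative sizes).
import torch
import torch.nn as nn

batch, T, inp_size, hid_size, out_size = 8, 5, 16, 32, 4
cell = nn.GRUCell(inp_size, hid_size)
fc = nn.Linear(hid_size, out_size)
criterion = nn.CrossEntropyLoss()

inp_x = torch.randn(batch, T, inp_size)        # (batch, time, features)
inp_y = torch.randint(0, out_size, (batch, ))  # one label per sequence

h = torch.zeros(batch, hid_size)               # init_states equivalent
for i in range(T):
    h = cell(inp_x[:, i], h)                   # one recurrent step
    output = fc(h)
    # with single_loss=False the loss would be accumulated here at every step
loss = criterion(output, inp_y)                # single_loss=True: last step only
loss.backward()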
Example #4
class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        num_timesteps (int, optional): Number of iterative refinement steps
            for global readout. (default: :obj:`4`)
        emb_dim (int, optional): Hidden node feature dimensionality.
            (default: :obj:`300`)
        num_layers (int, optional): Number of GNN layers. (default: :obj:`5`)
        drop_ratio (float, optional): Dropout probability. (default: :obj:`0`)
        num_tasks (int, optional): Size of each output sample.
            (default: :obj:`1`)
    """
    def __init__(self,
                 num_timesteps=4,
                 emb_dim=300,
                 num_layers=5,
                 drop_ratio=0,
                 num_tasks=1,
                 **args):
        super(AttentiveFP, self).__init__()

        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.drop_ratio = drop_ratio

        self.atom_encoder = AtomEncoder(emb_dim)
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

        conv = GATEConv(emb_dim, emb_dim, emb_dim, drop_ratio)
        gru = GRUCell(emb_dim, emb_dim)
        self.atom_convs = torch.nn.ModuleList([conv])
        self.atom_grus = torch.nn.ModuleList([gru])
        for _ in range(num_layers - 1):
            conv = GATConv(emb_dim,
                           emb_dim,
                           dropout=drop_ratio,
                           add_self_loops=False,
                           negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(emb_dim, emb_dim))

        self.mol_conv = GATConv(emb_dim,
                                emb_dim,
                                dropout=drop_ratio,
                                add_self_loops=False,
                                negative_slope=0.01)
        self.mol_gru = GRUCell(emb_dim, emb_dim)

        self.graph_pred_linear = Linear(emb_dim, num_tasks)

        self.reset_parameters()

    def reset_parameters(self):
        # self.atom_encoder.reset_parameters() # reset in init()
        # self.bond_encoder.reset_parameters() # reset in init()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.graph_pred_linear.reset_parameters()

    def forward(self, batched_data):
        """"""
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
        # Atom Embedding:
        x = F.leaky_relu_(self.atom_encoder(x))
        edge_attr = self.bond_encoder(edge_attr)

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.drop_ratio, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.drop_ratio, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.drop_ratio, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.drop_ratio, training=self.training)
        return self.graph_pred_linear(out)
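A hedged usage sketch for this OGB-flavoured variant. It assumes AtomEncoder/BondEncoder come from ogb.graphproppred.mol_encoder, that the class above and its GATEConv dependency are importable in this scope, and that the batched data comes from PygGraphPropPredDataset; the dataset name and batch size are illustrative choices, not taken from the original listing.

# Usage sketch (assumes the ogb and torch_geometric packages are installed).
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader  # torch_geometric.data in older PyG versions

dataset = PygGraphPropPredDataset(name='ogbg-molhiv')   # downloads on first use
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = AttentiveFP(num_timesteps=4, emb_dim=300, num_layers=5,
                    drop_ratio=0.2, num_tasks=dataset.num_tasks)

batched_data = next(iter(loader))  # Batch with .x, .edge_index, .edge_attr, .batch
pred = model(batched_data)         # shape: [num_graphs, num_tasks]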