class TGNMemory(torch.nn.Module):
    r"""The Temporal Graph Network (TGN) memory model from the
    `"Temporal Graph Networks for Deep Learning on Dynamic Graphs"
    <https://arxiv.org/abs/2006.10637>`_ paper.

    .. note::

        For an example of using TGN, see `examples/tgn.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        tgn.py>`_.

    Args:
        num_nodes (int): The number of nodes to save memories for.
        raw_msg_dim (int): The raw message dimensionality.
        memory_dim (int): The hidden memory dimensionality.
        time_dim (int): The time encoding dimensionality.
        message_module (torch.nn.Module): The message function which combines
            source and destination node memory embeddings, the raw message and
            the time encoding.
        aggregator_module (torch.nn.Module): The message aggregator function
            which aggregates messages to the same destination into a single
            representation.
    """
    def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,
                 time_dim: int, message_module: Callable,
                 aggregator_module: Callable):
        super().__init__()

        self.num_nodes = num_nodes
        self.raw_msg_dim = raw_msg_dim
        self.memory_dim = memory_dim
        self.time_dim = time_dim

        self.msg_s_module = message_module
        self.msg_d_module = copy.deepcopy(message_module)
        self.aggr_module = aggregator_module
        self.time_enc = TimeEncoder(time_dim)
        self.gru = GRUCell(message_module.out_channels, memory_dim)

        self.register_buffer('memory', torch.empty(num_nodes, memory_dim))
        last_update = torch.empty(self.num_nodes, dtype=torch.long)
        self.register_buffer('last_update', last_update)
        self.register_buffer('__assoc__',
                             torch.empty(num_nodes, dtype=torch.long))

        self.msg_s_store = {}
        self.msg_d_store = {}

        self.reset_parameters()

    def reset_parameters(self):
        if hasattr(self.msg_s_module, 'reset_parameters'):
            self.msg_s_module.reset_parameters()
        if hasattr(self.msg_d_module, 'reset_parameters'):
            self.msg_d_module.reset_parameters()
        if hasattr(self.aggr_module, 'reset_parameters'):
            self.aggr_module.reset_parameters()
        self.time_enc.reset_parameters()
        self.gru.reset_parameters()
        self.reset_state()

    def reset_state(self):
        """Resets the memory to its initial state."""
        zeros(self.memory)
        zeros(self.last_update)
        self.__reset_message_store__()

    def detach(self):
        """Detaches the memory from gradient computation."""
        self.memory.detach_()

    def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
        """Returns, for all nodes :obj:`n_id`, their current memory and their
        last updated timestamp."""
        if self.training:
            memory, last_update = self.__get_updated_memory__(n_id)
        else:
            memory, last_update = self.memory[n_id], self.last_update[n_id]

        return memory, last_update

    def update_state(self, src, dst, t, raw_msg):
        """Updates the memory with newly encountered interactions
        :obj:`(src, dst, t, raw_msg)`."""
        n_id = torch.cat([src, dst]).unique()

        if self.training:
            self.__update_memory__(n_id)
            self.__update_msg_store__(src, dst, t, raw_msg, self.msg_s_store)
            self.__update_msg_store__(dst, src, t, raw_msg, self.msg_d_store)
        else:
            self.__update_msg_store__(src, dst, t, raw_msg, self.msg_s_store)
            self.__update_msg_store__(dst, src, t, raw_msg, self.msg_d_store)
            self.__update_memory__(n_id)

    def __reset_message_store__(self):
        i = self.memory.new_empty((0, ), dtype=torch.long)
        msg = self.memory.new_empty((0, self.raw_msg_dim))
        # Message store format: (src, dst, t, msg)
        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}
        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}

    def __update_memory__(self, n_id):
        memory, last_update = self.__get_updated_memory__(n_id)
        self.memory[n_id] = memory
        self.last_update[n_id] = last_update

    def __get_updated_memory__(self, n_id):
        self.__assoc__[n_id] = torch.arange(n_id.size(0), device=n_id.device)

        # Compute messages (src -> dst).
        msg_s, t_s, src_s, dst_s = self.__compute_msg__(
            n_id, self.msg_s_store, self.msg_s_module)

        # Compute messages (dst -> src).
        msg_d, t_d, src_d, dst_d = self.__compute_msg__(
            n_id, self.msg_d_store, self.msg_d_module)

        # Aggregate messages.
        idx = torch.cat([src_s, src_d], dim=0)
        msg = torch.cat([msg_s, msg_d], dim=0)
        t = torch.cat([t_s, t_d], dim=0)
        aggr = self.aggr_module(msg, self.__assoc__[idx], t, n_id.size(0))

        # Get local copy of updated memory.
        memory = self.gru(aggr, self.memory[n_id])

        # Get local copy of updated `last_update`.
        dim_size = self.last_update.size(0)
        last_update = scatter_max(t, idx, dim=0, dim_size=dim_size)[0][n_id]

        return memory, last_update

    def __update_msg_store__(self, src, dst, t, raw_msg, msg_store):
        n_id, perm = src.sort()
        n_id, count = n_id.unique_consecutive(return_counts=True)
        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):
            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])

    def __compute_msg__(self, n_id, msg_store, msg_module):
        data = [msg_store[i] for i in n_id.tolist()]
        src, dst, t, raw_msg = list(zip(*data))
        src = torch.cat(src, dim=0)
        dst = torch.cat(dst, dim=0)
        t = torch.cat(t, dim=0)
        raw_msg = torch.cat(raw_msg, dim=0)
        t_rel = t - self.last_update[src]
        t_enc = self.time_enc(t_rel.to(raw_msg.dtype))

        msg = msg_module(self.memory[src], self.memory[dst], raw_msg, t_enc)

        return msg, t, src, dst

    def train(self, mode: bool = True):
        """Sets the module in training mode."""
        if self.training and not mode:
            # Flush message store to memory in case we just entered eval mode.
            self.__update_memory__(
                torch.arange(self.num_nodes, device=self.memory.device))
            self.__reset_message_store__()
        super().train(mode)
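
# --- Usage sketch (illustrative, not part of the original class) ------------
# A minimal sketch of how `TGNMemory` is typically driven on a stream of
# timestamped interactions `(src, dst, t, raw_msg)`, assuming the companion
# modules `IdentityMessage` and `LastAggregator` from
# `torch_geometric.nn.models.tgn` (as used in the referenced examples/tgn.py).
# The event tensors below are random placeholders.

def _tgn_memory_demo():
    from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator

    num_nodes, raw_msg_dim, memory_dim, time_dim = 100, 16, 32, 32
    memory = TGNMemory(
        num_nodes, raw_msg_dim, memory_dim, time_dim,
        message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
        aggregator_module=LastAggregator(),
    )

    # A fake batch of timestamped interactions.
    src = torch.randint(0, num_nodes, (20, ))
    dst = torch.randint(0, num_nodes, (20, ))
    t = torch.arange(20, dtype=torch.long)
    raw_msg = torch.randn(20, raw_msg_dim)

    memory.train()
    n_id = torch.cat([src, dst]).unique()
    z, last_update = memory(n_id)              # current memory of involved nodes
    memory.update_state(src, dst, t, raw_msg)  # store the new interactions
    memory.detach()                            # truncate backprop through time
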
class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        out_channels (int): Size of each output sample.
        edge_dim (int): Edge feature dimensionality.
        num_layers (int): Number of GNN layers.
        num_timesteps (int): Number of iterative refinement steps for global
            readout.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
    """
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, edge_dim: int, num_layers: int,
                 num_timesteps: int, dropout: float = 0.0):
        super().__init__()

        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)

        conv = GATEConv(hidden_channels, hidden_channels, edge_dim, dropout)
        gru = GRUCell(hidden_channels, hidden_channels)
        self.atom_convs = torch.nn.ModuleList([conv])
        self.atom_grus = torch.nn.ModuleList([gru])
        for _ in range(num_layers - 1):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout,
                           add_self_loops=False, negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))

        self.mol_conv = GATConv(hidden_channels, hidden_channels,
                                dropout=dropout, add_self_loops=False,
                                negative_slope=0.01)
        self.mol_gru = GRUCell(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin1.reset_parameters()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, edge_index, edge_attr, batch):
        """"""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)
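
# --- Usage sketch (illustrative, not part of the original class) ------------
# A minimal sketch of running the model above on one toy molecular graph,
# using the standard (x, edge_index, edge_attr, batch) inputs of a PyTorch
# Geometric mini-batch. All sizes and feature values are arbitrary
# placeholders.

def _attentive_fp_demo():
    num_atoms, in_channels, edge_dim = 6, 10, 4
    model = AttentiveFP(in_channels=in_channels, hidden_channels=64,
                        out_channels=1, edge_dim=edge_dim, num_layers=2,
                        num_timesteps=2, dropout=0.2)

    x = torch.randn(num_atoms, in_channels)                 # atom features
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]])         # a simple ring
    edge_attr = torch.randn(edge_index.size(1), edge_dim)   # bond features
    batch = torch.zeros(num_atoms, dtype=torch.long)        # single molecule

    out = model(x, edge_index, edge_attr, batch)  # -> [num_graphs, out_channels]
    return out
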
class PlainRNN(BaseModel):
    """A single-cell recurrent classifier.

    single_loss: if True, the loss is only calculated at the last time step;
        otherwise it is accumulated over all time steps.
    """
    def __init__(self, inp_size, hid_size, out_size, rnn_type="raw_rnn",
                 single_loss=True):
        super().__init__()
        allow_rnn_types = ["raw_rnn", "lstm", "gru"]
        assert rnn_type in allow_rnn_types
        self.rnn_type = rnn_type
        self.single_loss = single_loss
        self.inp_size = inp_size
        self.hid_size = hid_size
        self.out_size = out_size

        if rnn_type == "raw_rnn":
            self.lstm = RNNCell(inp_size, hid_size)
        if rnn_type == "lstm":
            self.lstm = LSTMCell(inp_size, hid_size)
        if rnn_type == "gru":
            self.lstm = GRUCell(inp_size, hid_size)

        self.fc1 = nn.Linear(hid_size, out_size)
        self.criterion = nn.CrossEntropyLoss()

    def init_weights(self):
        self.lstm.reset_parameters()
        self.fc1.reset_parameters()

    def init_states(self, batch_size):
        if self.rnn_type == "lstm":
            self.h = torch.zeros(batch_size, self.hid_size).to(device)
            self.c = torch.zeros(batch_size, self.hid_size).to(device)
        if self.rnn_type == "raw_rnn" or self.rnn_type == "gru":
            self.h = torch.zeros(batch_size, self.hid_size).to(device)

    def forward_train(self, x):
        assert len(x) == 2
        inp_x, inp_y = x
        inp_x = inp_x.to(device)
        inp_y = inp_y.to(device)
        batch, T, _ = inp_x.shape
        self.init_states(batch)

        dm_states = []
        loss = 0
        rr = torch.zeros((batch, self.out_size)).to(device)

        if self.rnn_type == "lstm":
            for i in range(T):
                # torch.nn.LSTMCell returns the new (h, c) state; the hidden
                # state serves as the step output.
                self.h, self.c = self.lstm(inp_x[:, i], (self.h, self.c))
                output = self.fc1(self.h)
                rr.copy_(output)
                dm_states.append(rr.cpu().detach().numpy())
                if not self.single_loss:
                    loss += self.criterion(output, inp_y.reshape(-1))
            if self.single_loss:
                loss += self.criterion(output, inp_y.reshape(-1))

        if self.rnn_type == "raw_rnn" or self.rnn_type == "gru":
            for i in range(T):
                self.h = self.lstm(inp_x[:, i], self.h)
                output = self.fc1(self.h)
                rr.copy_(output)
                dm_states.append(rr.cpu().detach().numpy())
                if not self.single_loss:
                    loss += self.criterion(output, inp_y.reshape(-1))
            if self.single_loss:
                loss += self.criterion(output, inp_y.reshape(-1))

        print("loss is ", loss.cpu().item())
        outputs = dict(
            loss=loss,
            outputs=dm_states,
        )
        return outputs

    def forward_test(self, x):
        assert len(x) == 2
        inp_x, inp_y = x
        inp_x = inp_x.to(device)
        batch, T, _ = inp_x.shape
        self.init_states(batch)

        dm_states = []
        rr = torch.zeros((batch, self.out_size)).to(device)

        if self.rnn_type == "lstm":
            for i in range(T):
                self.h, self.c = self.lstm(inp_x[:, i], (self.h, self.c))
                output = self.fc1(self.h)
                rr.copy_(output)
                dm_states.append(rr.cpu().detach().numpy())

        if self.rnn_type == "raw_rnn" or self.rnn_type == "gru":
            for i in range(T):
                self.h = self.lstm(inp_x[:, i], self.h)
                output = self.fc1(self.h)
                rr.copy_(output)
                dm_states.append(rr.cpu().detach().numpy())

        if self.single_loss:
            # Keep only the prediction from the last time step.
            dm_states = (np.array(dm_states)[-1]).reshape(1, batch, -1)
        else:
            # Majority vote: one-hot encode each step's argmax, then average
            # over time.
            dm_states = np.array(dm_states).reshape(T * batch, -1)
            dm_states_ = np.zeros_like(dm_states)
            index = np.argmax(dm_states, axis=1)
            dm_states_[range(T * batch), index] = 1
            dm_states = dm_states_.reshape(T, batch, -1).mean(axis=0)
            dm_states = dm_states.reshape(1, batch, -1)

        inp_y = inp_y.view(-1).cpu().numpy()
        outputs = dict(
            outputs=dm_states,
            labels=inp_y,
        )
        return outputs
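
# --- Usage sketch (illustrative, not part of the original class) ------------
# A minimal sketch of feeding a dummy classification sequence through
# `PlainRNN`. It assumes the module-level `device` global referenced by the
# class and a `BaseModel` base class that dispatches to
# `forward_train`/`forward_test`. Input is `(inp_x, inp_y)` with `inp_x` of
# shape `[batch, T, inp_size]` and integer class labels `inp_y`.

def _plain_rnn_demo():
    batch, T, inp_size, hid_size, out_size = 4, 10, 8, 32, 5
    model = PlainRNN(inp_size, hid_size, out_size, rnn_type="gru",
                     single_loss=True).to(device)

    inp_x = torch.randn(batch, T, inp_size)
    inp_y = torch.randint(0, out_size, (batch, ))

    train_out = model.forward_train((inp_x, inp_y))   # dict with 'loss', 'outputs'
    train_out["loss"].backward()

    with torch.no_grad():
        test_out = model.forward_test((inp_x, inp_y))  # dict with 'outputs', 'labels'
    return test_out
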
class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        emb_dim (int): Hidden node feature dimensionality.
        num_tasks (int): Size of each output sample.
        num_layers (int): Number of GNN layers.
        num_timesteps (int): Number of iterative refinement steps for global
            readout.
        drop_ratio (float, optional): Dropout probability.
            (default: :obj:`0.0`)
    """
    def __init__(self, num_timesteps=4, emb_dim=300, num_layers=5,
                 drop_ratio=0, num_tasks=1, **args):
        super(AttentiveFP, self).__init__()

        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.drop_ratio = drop_ratio

        self.atom_encoder = AtomEncoder(emb_dim)
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

        conv = GATEConv(emb_dim, emb_dim, emb_dim, drop_ratio)
        gru = GRUCell(emb_dim, emb_dim)
        self.atom_convs = torch.nn.ModuleList([conv])
        self.atom_grus = torch.nn.ModuleList([gru])
        for _ in range(num_layers - 1):
            conv = GATConv(emb_dim, emb_dim, dropout=drop_ratio,
                           add_self_loops=False, negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(emb_dim, emb_dim))

        self.mol_conv = GATConv(emb_dim, emb_dim, dropout=drop_ratio,
                                add_self_loops=False, negative_slope=0.01)
        self.mol_gru = GRUCell(emb_dim, emb_dim)

        self.graph_pred_linear = Linear(emb_dim, num_tasks)

        self.reset_parameters()

    def reset_parameters(self):
        # self.atom_encoder.reset_parameters()  # reset in __init__()
        # self.bond_encoder.reset_parameters()  # reset in __init__()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.graph_pred_linear.reset_parameters()

    def forward(self, batched_data):
        """"""
        x, edge_index, edge_attr, batch = (batched_data.x,
                                           batched_data.edge_index,
                                           batched_data.edge_attr,
                                           batched_data.batch)

        # Atom Embedding:
        x = F.leaky_relu_(self.atom_encoder(x))
        edge_attr = self.bond_encoder(edge_attr)

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.drop_ratio, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.drop_ratio, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.drop_ratio, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.drop_ratio, training=self.training)
        return self.graph_pred_linear(out)
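
# --- Usage sketch (illustrative, not part of the original class) ------------
# A minimal sketch of running the OGB-style variant above, assuming that
# `AtomEncoder`/`BondEncoder` come from `ogb.graphproppred.mol_encoder` and
# that `batched_data` is a PyTorch Geometric batch from an OGB molecular
# dataset (e.g. ogbg-molhiv), whose integer atom/bond features those encoders
# expect. The batch itself is taken as an argument rather than fabricated here.

def _ogb_attentive_fp_demo(batched_data):
    model = AttentiveFP(num_timesteps=4, emb_dim=300, num_layers=5,
                        drop_ratio=0.2, num_tasks=1)
    model.eval()
    with torch.no_grad():
        pred = model(batched_data)  # -> [num_graphs, num_tasks]
    return pred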