def __init__(self, in_channels: int, hidden_channels: int, out_channels: int,
             edge_dim: int, num_layers: int, num_timesteps: int,
             dropout: float = 0.0):
    super(AttentiveFP, self).__init__()

    self.num_layers = num_layers
    self.num_timesteps = num_timesteps
    self.dropout = dropout

    self.lin1 = Linear(in_channels, hidden_channels)

    conv = GATEConv(hidden_channels, hidden_channels, edge_dim, dropout)
    gru = GRUCell(hidden_channels, hidden_channels)
    self.atom_convs = torch.nn.ModuleList([conv])
    self.atom_grus = torch.nn.ModuleList([gru])
    for _ in range(num_layers - 1):
        conv = GATConv(hidden_channels, hidden_channels, dropout=dropout,
                       add_self_loops=False, negative_slope=0.01)
        self.atom_convs.append(conv)
        self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))

    self.mol_conv = GATConv(hidden_channels, hidden_channels,
                            dropout=dropout, add_self_loops=False,
                            negative_slope=0.01)
    self.mol_gru = GRUCell(hidden_channels, hidden_channels)

    self.lin2 = Linear(hidden_channels, out_channels)

    self.reset_parameters()
def __init__(self, word_vec_dim, hidden_size, session_num_layers=2,
             user_num_layers=2, dropout_p=0.1):
    super(HierarchicalRNNCell, self).__init__()
    self.word_vec_dim = word_vec_dim
    self.hidden_size = hidden_size
    self.session_num_layers = session_num_layers
    self.user_num_layers = user_num_layers
    self.dropout_p = dropout_p

    self.session_rnn_cells = [
        GRUCell(input_size=word_vec_dim, hidden_size=hidden_size)
    ]
    self.session_rnn_cells = self.session_rnn_cells + [
        GRUCell(input_size=hidden_size, hidden_size=hidden_size)
        for _ in range(session_num_layers - 1)
    ]
    self.session_rnn_cells = ListModule(*self.session_rnn_cells)

    self.user_rnn_cells = [
        GRUCell(input_size=hidden_size, hidden_size=hidden_size)
        for _ in range(user_num_layers)
    ]
    self.user_rnn_cells = ListModule(*self.user_rnn_cells)

    self.dropout = nn.Dropout(p=dropout_p)
def __init__(self, node_in_channels, node_out_channels, edge_in_channels,
             edge_out_channels, heads=5, dropout=0.1,
             parameter_efficient=True, attention_channels=None,
             principal_neighbourhood_aggregation=False, deg=None, aggr='add'):
    super().__init__()
    self.conv = GATEConv(
        node_in_channels=node_in_channels,
        node_out_channels=node_out_channels,
        edge_in_channels=edge_in_channels,
        edge_out_channels=edge_out_channels,
        heads=heads,
        concat=True,
        attention_channels=attention_channels,
        parameter_efficient=parameter_efficient,
        principal_neighbourhood_aggregation=principal_neighbourhood_aggregation,
        deg=deg,
        aggr=aggr)
    self.node_norm = BatchNorm(node_in_channels)
    self.edge_norm = BatchNorm(edge_in_channels)
    self.node_aggr = GRUCell(node_in_channels, node_out_channels)
    self.edge_aggr = GRUCell(edge_in_channels, edge_out_channels)
    self.dropout = dropout
    self.edge_channels = edge_in_channels
def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,
             time_dim: int, message_module: Callable,
             aggregator_module: Callable):
    super().__init__()

    self.num_nodes = num_nodes
    self.raw_msg_dim = raw_msg_dim
    self.memory_dim = memory_dim
    self.time_dim = time_dim

    self.msg_s_module = message_module
    self.msg_d_module = copy.deepcopy(message_module)
    self.aggr_module = aggregator_module
    self.time_enc = TimeEncoder(time_dim)
    self.gru = GRUCell(message_module.out_channels, memory_dim)

    self.register_buffer('memory', torch.empty(num_nodes, memory_dim))
    last_update = torch.empty(self.num_nodes, dtype=torch.long)
    self.register_buffer('last_update', last_update)
    self.register_buffer('__assoc__',
                         torch.empty(num_nodes, dtype=torch.long))

    self.msg_s_store = {}
    self.msg_d_store = {}

    self.reset_parameters()
def __init__(self, in_features, memory_dim, r, eps=0):
    super(Decoder, self).__init__()
    self.max_decoder_steps = 200
    self.memory_dim = memory_dim
    self.eps = eps
    self.r = r
    self.prenet = Prenet(in_features=(self.memory_dim * self.r),
                         out_features=[256, 128])
    self.attention_rnn = AttentionRNN(out_dim=256, annot_dim=in_features,
                                      memory_dim=128)
    self.project_to_decoder_in = Linear(in_features=(256 + in_features),
                                        out_features=256)
    # 2-layer residual GRU (256 cells)
    self.decoder_rnns = ModuleList(
        [GRUCell(input_size=256, hidden_size=256),
         GRUCell(input_size=256, hidden_size=256)])
    self.proj_to_mel = Linear(in_features=256,
                              out_features=(self.memory_dim * self.r))
def __init__(self, c_dim: int, m_dim: int, p_dim: int, radius=2,
             use_cuda=False, dropout=0., use_gru=True):
    super(AlignAttendPooling, self).__init__()
    self.use_cuda = use_cuda
    self.use_gru = use_gru
    self.radius = radius
    self.map = Linear(c_dim, m_dim)
    self.relu = LeakyReLU()
    self.relu1 = LeakyReLU()
    if use_gru:
        self.gru = GRUCell(c_dim, m_dim)
    else:
        self.linear = Linear(c_dim + m_dim, m_dim)
    self.attend = Linear(c_dim + p_dim, c_dim)
    self.align = Linear(m_dim + c_dim + p_dim, 1)
    self.softmax = Softmax(dim=1)
    self.elu = ELU()
    self.relu2 = ReLU()
    self.dropout = Dropout(p=dropout)
def __init__(self, input_dim, contxt_dim, hidden_dim, att_module,
             pointer_gen_module, cat_contx_to_inp=True,
             pass_extra_feat_to_pg=False, copy_prob=None, **kwargs):
    """
    :param pass_extra_feat_to_pg: whether to pass (concatenated) extra
        features to the pointer-generator network or just embeddings.
    """
    super(GruPointerDecoder, self).__init__()
    self.input_dim = input_dim
    self.contxt_dim = contxt_dim
    self.hidden_dim = hidden_dim
    self.att = att_module
    self.cat_conxt_to_inp = cat_contx_to_inp
    cell_inp_dim = input_dim + contxt_dim if cat_contx_to_inp else input_dim
    self.gru_cell = GRUCell(cell_inp_dim, hidden_dim, **kwargs)
    self.pgn = pointer_gen_module
    self.copy_prob = copy_prob
    self.pass_extra_feat_to_pg = pass_extra_feat_to_pg
def __init__(self, cnn_channels, cnn_dropout, rnn_in_dim, rnn_out_dim,
             rnn_dropout, nb_classes):
    """The CRNN model.

    :param cnn_channels: The amount of CNN channels.
    :type cnn_channels: int
    :param cnn_dropout: The dropout to be applied to the CNNs.
    :type cnn_dropout: float
    :param rnn_in_dim: The input dimensionality of the RNN.
    :type rnn_in_dim: int
    :param rnn_out_dim: The output dimensionality of the RNN.
    :type rnn_out_dim: int
    :param rnn_dropout: The dropout to be applied to the RNN.
    :type rnn_dropout: float
    :param nb_classes: The amount of classes to be predicted.
    :type nb_classes: int
    """
    super(CRNN, self).__init__()

    self.dnn_output_features = cnn_channels
    self.rnn_hh_size = rnn_out_dim
    self.nb_classes = nb_classes

    self.dnn = Sequential(
        dnn.DNN(cnn_channels=cnn_channels, cnn_dropout=cnn_dropout),
        Dropout(rnn_dropout)
    )

    self.rnn = GRUCell(rnn_in_dim, self.rnn_hh_size, bias=True)

    self.classifier = Linear(self.rnn_hh_size, self.nb_classes, bias=True)
def __init__(self, cnn_channels: int, cnn_dropout: float, rnn_in_dim: int,
             rnn_out_dim: int, nb_classes: int) -> None:
    """The DESSED model.

    :param cnn_channels: Amount of CNN channels.
    :type cnn_channels: int
    :param cnn_dropout: Dropout to be applied to the CNNs.
    :type cnn_dropout: float
    :param rnn_in_dim: Input dimensionality of the RNN.
    :type rnn_in_dim: int
    :param rnn_out_dim: Output dimensionality of the RNN.
    :type rnn_out_dim: int
    :param nb_classes: Amount of classes to be predicted.
    :type nb_classes: int
    """
    super().__init__()

    self.rnn_hh_size: int = rnn_out_dim
    self.nb_classes: int = nb_classes

    self.dnn: Module = DepthWiseSeparableDNN(cnn_channels=cnn_channels,
                                             cnn_dropout=cnn_dropout)

    self.rnn: Module = GRUCell(input_size=rnn_in_dim,
                               hidden_size=self.rnn_hh_size,
                               bias=True)

    self.classifier: Module = Linear(in_features=self.rnn_hh_size,
                                     out_features=self.nb_classes,
                                     bias=True)
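# The CRNN/DESSED constructors above pair a convolutional front-end with a
# GRUCell that is stepped manually over time frames. Their actual forward
# passes are not shown in this collection, so the following is only a minimal,
# self-contained sketch of that common pattern; all names and shapes below are
# illustrative assumptions, not part of the original code.
import torch
from torch.nn import GRUCell, Linear

batch_size, time_steps, feat_dim, hidden_dim, nb_classes = 4, 10, 64, 32, 6
rnn = GRUCell(feat_dim, hidden_dim)
classifier = Linear(hidden_dim, nb_classes)

features = torch.randn(batch_size, time_steps, feat_dim)  # e.g. CNN output
h = features.new_zeros(batch_size, hidden_dim)            # initial hidden state
logits = []
for t in range(time_steps):
    h = rnn(features[:, t, :], h)     # one recurrent step per frame
    logits.append(classifier(h))      # frame-level class scores
logits = torch.stack(logits, dim=1)   # (batch, time, classes)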
def __init__(self, arch: dict):
    super().__init__()
    self.ldim = arch['latent_dim']
    self.defaultSteps = arch['std_T']

    self.Cinit = torch.nn.Parameter(torch.FloatTensor(self.ldim))
    torch.nn.init.normal_(self.Cinit)
    self.Linit = torch.nn.Parameter(torch.FloatTensor(self.ldim))
    torch.nn.init.normal_(self.Linit)

    self.rec_block = arch['recurrent_block']
    if self.rec_block == 'test':
        # self.block = BiPartialTestBlock(self.ldim, self.ldim, 2 * self.ldim, 2)
        Ldim, Cdim, Hdim, Dpth = self.ldim, self.ldim, 2 * self.ldim, 2
        self.Cmsg = batchMLP(Cdim, Hdim, Cdim, Dpth, False)
        self.Lmsg = batchMLP(Ldim, Hdim, Ldim, Dpth, False)
        self.Cu = batchMLP(Cdim * 2, Hdim, Cdim, Dpth, False)
        self.Lu = batchMLP(Ldim * 3, Hdim, Ldim, Dpth, False)
    elif self.rec_block in ['std_lstm', 'ln_lstm', 'gru']:
        Ldim, Cdim, Hdim, Dpth = self.ldim, self.ldim, self.ldim, 4
        self.Cmsg = batchMLP(Cdim, Hdim, Cdim, Dpth, False)
        self.Lmsg = batchMLP(Ldim, Hdim, Ldim, Dpth, False)
        if self.rec_block == 'std_lstm':
            self.Cu = LSTMCell(self.ldim, self.ldim, True)
            self.Lu = LSTMCell(self.ldim * 2, self.ldim, True)
        elif self.rec_block == 'ln_lstm':
            self.Cu = ln_LSTMCell(self.ldim, self.ldim, True)
            self.Lu = ln_LSTMCell(self.ldim * 2, self.ldim, True)
        elif self.rec_block == 'gru':
            self.Cu = GRUCell(self.ldim, self.ldim, True)
            self.Lu = GRUCell(self.ldim * 2, self.ldim, True)

    self.cl = arch['classifier']
    if self.cl == 'NeuroSAT':
        self.Lvote = batchMLP(self.ldim, 2 * self.ldim, 1, 2, False)
    elif self.cl == 'CircuitSAT-like':
        self.tnormf = arch['tnorm']
        if 'tnorm_train' in arch:
            self.train_tnorm = arch['tnorm_train']
        else:
            self.train_tnorm = self.tnormf
        self.tnorm_tmp = arch['tnorm_temperature']
        self.train_temp = arch['temp_train']
        self.test_temp = arch['temp_test']
        self.Lvote = batchMLP(self.ldim, 2 * self.ldim, 1, 2, False)
def __init__(self, num_L):
    super(DrBC, self).__init__()
    self.num_L = num_L
    self.data_in = Sequential(Linear(3, 128), ReLU())
    self.Aggregation = GCNConv(128, 128)
    self.Combine = GRUCell(128, 128, bias=False)
    self.data_out = Sequential(Linear(128, 64), ReLU(), Linear(64, 1), ReLU())
def __init__(self, input_dim, hidden_dim, dropout_prob=0., hidden_norm=False,
             **kwargs):
    super(GruMaskedEncoder, self).__init__()
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.dropout_prob = dropout_prob
    self.gru_cell = GRUCell(input_dim, hidden_dim, **kwargs)
    self.dropout_layer = Dropout(dropout_prob)
    self.hidden_norm = LayerNorm(hidden_dim) if hidden_norm else None
def get_pytorch_gru(input_size, used_gru):
    """Load in a PyTorch GRU that is a copy of the currently used GRU."""
    gru = PytorchGRU(input_size, 1)
    if type(used_gru) == GRUBerkeley:
        gru.bias_hh[:] = tensor(zeros((3,)), dtype=float64)[:]
        gru.bias_ih[:] = tensor(used_gru.bias, dtype=float64)[:]
        gru.weight_hh[:] = tensor(used_gru.weight_hh, dtype=float64)[:]
        gru.weight_ih[:] = tensor(used_gru.weight_xh, dtype=float64)[:]
    elif type(used_gru) == GRUPyTorch:
        gru.bias_hh[:] = tensor(used_gru.bias_hh, dtype=float64)[:]
        gru.bias_ih[:] = tensor(used_gru.bias_ih, dtype=float64)[:]
        gru.weight_hh[:] = tensor(used_gru.weight_hh, dtype=float64)[:]
        gru.weight_ih[:] = tensor(used_gru.weight_ih, dtype=float64)[:]
    else:
        raise Exception(f"Invalid input for used_gru: {used_gru}")
    return gru
def __init__(self, params: configargparse.Namespace,
             att: torch.nn.Module = None):
    """
    Neural Network Module for the Sequence to Sequence LAS Model

    :param configargparse.Namespace params: The training options
    :param torch.nn.Module att: The attention module
    """
    super(Speller, self).__init__()
    ## Embedding Layer
    self.embed = Embedding(params.odim, params.demb_dim)
    ## Decoder with LSTM/GRU Cells
    self.decoder = ModuleList()
    self.dropout_dec = ModuleList()
    self.dtype = params.dtype
    self.dunits = params.dhiddens
    self.dlayers = params.dlayers
    self.decoder += [
        LSTMCell(params.eprojs + params.demb_dim, params.dhiddens)
        if self.dtype == "lstm"
        else GRUCell(params.eprojs + params.demb_dim, params.dhiddens)
    ]
    self.dropout_dec += [Dropout(p=params.ddropout)]
    self.dropout_emb = Dropout(p=params.ddropout)
    ## Other decoder layers if > 1 decoder layer
    for i in range(1, params.dlayers):
        self.decoder += [
            LSTMCell(params.dhiddens, params.dhiddens)
            if self.dtype == "lstm"
            else GRUCell(params.dhiddens, params.dhiddens)
        ]
        self.dropout_dec += [LockedDropout(p=params.ddropout)]  # Dropout
    ## Project to Softmax Space - Output
    self.projections = Linear(params.dhiddens, params.odim)
    ## Attention Module
    self.att = att
    ## Scheduled Sampling
    self.sampling_probability = params.ssprob
    ## Initialize EOS, SOS
    self.eos = len(params.char_list) - 1
    self.sos = self.eos
    self.ignore_id = params.text_pad
def __init__(self, input_dim, context_length, debug):
    """The RNN encoder of the Masker.

    :param input_dim: The input dimensionality.
    :type input_dim: int
    :param context_length: The context length.
    :type context_length: int
    :param debug: Flag to indicate debug.
    :type debug: bool
    """
    super(RNNEnc, self).__init__()

    self._input_dim = input_dim
    self._context_length = context_length

    self.gru_enc_f = GRUCell(self._input_dim, self._input_dim)
    self.gru_enc_b = GRUCell(self._input_dim, self._input_dim)

    self._debug = debug

    self.initialize_encoder()
def __init__(
    self,
    input_size: int,
    hidden_size: int,
) -> None:
    super().__init__()
    # We'll use a GRU cell as the recurrent cell that produces a hidden state
    # for the decoder at each time step.
    self._generator = GRUCell(input_size, hidden_size)
    self._projection = nn.Linear(2 * hidden_size, hidden_size)
def __init__(self, inp_size, hid_size, out_size, rnn_type="raw_rnn",
             single_loss=True):
    super().__init__()
    allow_rnn_types = ["raw_rnn", "lstm", "gru"]
    assert rnn_type in allow_rnn_types
    self.rnn_type = rnn_type
    self.single_loss = single_loss
    self.inp_size = inp_size
    self.hid_size = hid_size
    self.out_size = out_size
    if rnn_type == "raw_rnn":
        self.lstm = RNNCell(inp_size, hid_size)
    if rnn_type == "lstm":
        self.lstm = LSTMCell(inp_size, hid_size)
    if rnn_type == "gru":
        self.lstm = GRUCell(inp_size, hid_size)
    self.fc1 = nn.Linear(hid_size, out_size)
    self.criterion = nn.CrossEntropyLoss()
def __init__(self, num_timesteps=4, emb_dim=300, num_layers=5, drop_ratio=0,
             num_tasks=1, **args):
    super(AttentiveFP, self).__init__()

    self.num_layers = num_layers
    self.num_timesteps = num_timesteps
    self.drop_ratio = drop_ratio

    self.atom_encoder = AtomEncoder(emb_dim)
    self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    conv = GATEConv(emb_dim, emb_dim, emb_dim, drop_ratio)
    gru = GRUCell(emb_dim, emb_dim)
    self.atom_convs = torch.nn.ModuleList([conv])
    self.atom_grus = torch.nn.ModuleList([gru])
    for _ in range(num_layers - 1):
        conv = GATConv(emb_dim, emb_dim, dropout=drop_ratio,
                       add_self_loops=False, negative_slope=0.01)
        self.atom_convs.append(conv)
        self.atom_grus.append(GRUCell(emb_dim, emb_dim))

    self.mol_conv = GATConv(emb_dim, emb_dim, dropout=drop_ratio,
                            add_self_loops=False, negative_slope=0.01)
    self.mol_gru = GRUCell(emb_dim, emb_dim)

    self.graph_pred_linear = Linear(emb_dim, num_tasks)

    self.reset_parameters()
def __init__(self, voc_size, hidden_size, device, num_layers=2):
    super().__init__()
    self.encoder = Encoder(voc_size, hidden_size, num_layers)
    self.gru = GRU(input_size=hidden_size, hidden_size=hidden_size,
                   num_layers=num_layers)
    # NOTE: a plain Python list does not register these cells as submodules;
    # wrapping them in torch.nn.ModuleList would expose their parameters to
    # .parameters() and .to().
    self.grucell_list = [
        GRUCell(input_size=hidden_size, hidden_size=hidden_size).to(device)
        for i in range(num_layers)
    ]
    self.score = Linear(in_features=hidden_size * num_layers, out_features=1)
    self.num_layers = num_layers
    self.hidden_size = hidden_size
    self.device = device
def __init__(self, input_dim, debug):
    """The RNN dec of the Masker.

    :param input_dim: The input dimensionality.
    :type input_dim: int
    :param debug: Flag to indicate debug.
    :type debug: bool
    """
    super(RNNDec, self).__init__()

    self._input_dim = input_dim
    self.gru_dec = GRUCell(self._input_dim, self._input_dim)

    self._debug = debug

    self.initialize_decoder()
def __init__(self, cnn_channels, cnn_dropout, rnn_in_dim, rnn_out_dim,
             rnn_dropout, nb_classes, batch_counter, gamma_factor,
             mul_factor, min_prob, max_prob):
    """The Sound Event Detection (SED) model with teacher forcing and
    scheduled sampling.

    :param cnn_channels: The amount of CNN channels for the SED model.
    :type cnn_channels: int
    :param cnn_dropout: The dropout percentage for the CNNs dropout.
    :type cnn_dropout: float
    :param rnn_in_dim: The input dimensionality for the RNNs.
    :type rnn_in_dim: int
    :param rnn_out_dim: The output dimensionality for the RNNs.
    :type rnn_out_dim: int
    :param rnn_dropout: The dropout percentage for the RNN dropout.
    :type rnn_dropout: float
    :param nb_classes: The amount of output classes.
    :type nb_classes: int
    :param batch_counter: The amount of batches in one full epoch.
    :type batch_counter: int
    :param gamma_factor: The gamma factor for scheduled sampling.
    :type gamma_factor: float
    :param mul_factor: The multiplication factor for scheduled sampling.
    :type mul_factor: float
    :param min_prob: The minimum probability for selecting predictions.
    :type min_prob: float
    :param max_prob: The maximum probability for selecting predictions.
    :type max_prob: float
    """
    super(TFCRNN, self).__init__()

    self.dnn_output_features = cnn_channels
    self.rnn_hh_size = rnn_out_dim
    self.nb_classes = nb_classes

    self.batch_counter = batch_counter
    self.gamma_factor = gamma_factor / mul_factor
    self._min_prob = 1 - min_prob
    self.max_prob = max_prob
    self.iteration = 0

    self.dnn = dnn.DNN(cnn_channels=cnn_channels, cnn_dropout=cnn_dropout)
    self.rnn_dropout = Dropout(rnn_dropout)
    self.rnn = GRUCell(rnn_in_dim + self.nb_classes, self.rnn_hh_size,
                       bias=True)
    self.classifier = Linear(self.rnn_hh_size, self.nb_classes, bias=True)
def __init__(self, is_training, batch_size, message_size, scaler, adj_mx,
             **model_kwargs):
    # Scaler for data normalization.
    super(MPNNModel, self).__init__()
    self._scaler = scaler
    self._horizon = int(model_kwargs.get('horizon', 1))
    self._num_nodes = int(model_kwargs.get('num_nodes', 1))
    self._rnn_units = int(model_kwargs.get('rnn_units'))
    self._input_dim = int(model_kwargs.get('input_dim', 1))
    self._output_dim = int(model_kwargs.get('output_dim', 1))
    self._message_size = message_size
    self._batch_size = batch_size
    self._adj_mx = adj_mx
    self._cell = GRUCell((self._input_dim + self._message_size),
                         self._rnn_units)
    self._M_t = Linear(self._rnn_units, self._message_size)
    self._R_t = Linear(self._rnn_units, self._output_dim * self._horizon)
def __init__(
    self,
    in_feats: int,
    out_feats: int,
    n_steps: int,
    n_etypes: int,
    bias: bool = True,
) -> None:
    """Construct a GGNN layer."""
    super().__init__()
    self.in_feats = in_feats
    self.out_feats = out_feats
    self.n_steps = n_steps
    self.n_etypes = n_etypes
    self._linears = ModuleList(
        [Linear(out_feats, out_feats) for _n in range(n_etypes)]
    )
    self._gru = GRUCell(input_size=out_feats, hidden_size=out_feats, bias=bias)
def __init__(self, input_dim, debug):
    """The RNN dec of the Masker.

    :param input_dim: The input dimensionality.
    :type input_dim: int
    :param debug: Flag to indicate debug.
    :type debug: bool
    """
    super(RNNDec, self).__init__()

    self._input_dim = input_dim
    self.gru_dec = GRUCell(self._input_dim, self._input_dim)

    self._debug = debug
    self._device = 'cuda' if not self._debug and torch.cuda.is_available() \
        else 'cpu'

    self.initialize_decoder()
def __init__(self, input_dim, contxt_dim, hidden_dim, att_module,
             input_norm=False, hidden_norm=False, cat_contx_to_inp=True,
             **kwargs):
    """
    :param input_dim: usually embeddings dim.
    :param hidden_dim: dimensionality of the hidden layer.
    :param att_module: module to attend over some external values.
    :param input_norm: if set to True, will use layer normalization to
        normalize inputs to the decoder.
    :param hidden_norm: if set to True, will use layer normalization over
        produced hidden states.
    :param cat_contx_to_inp: whether to concatenate the context vector to the
        input word embeddings.
    :param kwargs:
    """
    super(GruAttDecoder, self).__init__()
    self.input_dim = input_dim
    self.contxt_dim = contxt_dim
    self.hidden_dim = hidden_dim
    self.att = att_module
    self.input_norm = input_norm
    self.hidden_norm = hidden_norm
    self.cat_conxt_to_inp = cat_contx_to_inp
    cell_inp_dim = input_dim + contxt_dim if cat_contx_to_inp else input_dim
    self.gru_cell = GRUCell(cell_inp_dim, hidden_dim, **kwargs)
    if self.input_norm:
        self.inp_norm_layer = LayerNorm(input_dim + contxt_dim)
    if self.hidden_norm:
        self.hidden_norm_layer = LayerNorm(hidden_dim)
def __init__(self, input_dim: int, output_dim: int, nb_classes: int,
             dropout_p: float, max_out_t_steps: Optional[int] = 22,
             attention_dropout: Optional[float] = .25,
             attention_dim: Optional[Union[List[int], None]] = None) -> None:
    """Attention decoder for the baseline audio captioning method.

    :param input_dim: Input dimensionality for the RNN.
    :type input_dim: int
    :param output_dim: Output dimensionality for the RNN.
    :type output_dim: int
    :param nb_classes: Amount of classes.
    :type nb_classes: int
    :param dropout_p: RNN dropout.
    :type dropout_p: float
    :param max_out_t_steps: Maximum output time steps during inference.
    :type max_out_t_steps: int
    :param attention_dropout: Dropout for attention, defaults to .25.
    :type attention_dropout: float, optional
    :param attention_dim: Dimensionality of attention layers, defaults to None.
    :type attention_dim: list[int] | None, optional
    """
    super().__init__()

    self.input_dim = input_dim
    self.output_dim = output_dim
    self.nb_classes = nb_classes
    self.max_out_t_steps = max_out_t_steps

    self.dropout: Module = Dropout(p=dropout_p)

    self.attention: Attention = Attention(input_dim=self.input_dim,
                                          h_dim=self.output_dim,
                                          dropout_p=attention_dropout,
                                          layers_dim=attention_dim)

    self.gru: Module = GRUCell(self.input_dim, self.output_dim)
    self.classifier: Module = Linear(self.output_dim, self.nb_classes)
class TGNMemory(torch.nn.Module):
    r"""The Temporal Graph Network (TGN) memory model from the
    `"Temporal Graph Networks for Deep Learning on Dynamic Graphs"
    <https://arxiv.org/abs/2006.10637>`_ paper.

    .. note::

        For an example of using TGN, see `examples/tgn.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        tgn.py>`_.

    Args:
        num_nodes (int): The number of nodes to save memories for.
        raw_msg_dim (int): The raw message dimensionality.
        memory_dim (int): The hidden memory dimensionality.
        time_dim (int): The time encoding dimensionality.
        message_module (torch.nn.Module): The message function which
            combines source and destination node memory embeddings, the raw
            message and the time encoding.
        aggregator_module (torch.nn.Module): The message aggregator function
            which aggregates messages to the same destination into a single
            representation.
    """
    def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,
                 time_dim: int, message_module: Callable,
                 aggregator_module: Callable):
        super().__init__()

        self.num_nodes = num_nodes
        self.raw_msg_dim = raw_msg_dim
        self.memory_dim = memory_dim
        self.time_dim = time_dim

        self.msg_s_module = message_module
        self.msg_d_module = copy.deepcopy(message_module)
        self.aggr_module = aggregator_module
        self.time_enc = TimeEncoder(time_dim)
        self.gru = GRUCell(message_module.out_channels, memory_dim)

        self.register_buffer('memory', torch.empty(num_nodes, memory_dim))
        last_update = torch.empty(self.num_nodes, dtype=torch.long)
        self.register_buffer('last_update', last_update)
        self.register_buffer('__assoc__',
                             torch.empty(num_nodes, dtype=torch.long))

        self.msg_s_store = {}
        self.msg_d_store = {}

        self.reset_parameters()

    def reset_parameters(self):
        if hasattr(self.msg_s_module, 'reset_parameters'):
            self.msg_s_module.reset_parameters()
        if hasattr(self.msg_d_module, 'reset_parameters'):
            self.msg_d_module.reset_parameters()
        if hasattr(self.aggr_module, 'reset_parameters'):
            self.aggr_module.reset_parameters()
        self.time_enc.reset_parameters()
        self.gru.reset_parameters()
        self.reset_state()

    def reset_state(self):
        """Resets the memory to its initial state."""
        zeros(self.memory)
        zeros(self.last_update)
        self.__reset_message_store__()

    def detach(self):
        """Detaches the memory from gradient computation."""
        self.memory.detach_()

    def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
        """Returns, for all nodes :obj:`n_id`, their current memory and their
        last updated timestamp."""
        if self.training:
            memory, last_update = self.__get_updated_memory__(n_id)
        else:
            memory, last_update = self.memory[n_id], self.last_update[n_id]

        return memory, last_update

    def update_state(self, src, dst, t, raw_msg):
        """Updates the memory with newly encountered interactions
        :obj:`(src, dst, t, raw_msg)`."""
        n_id = torch.cat([src, dst]).unique()

        if self.training:
            self.__update_memory__(n_id)
            self.__update_msg_store__(src, dst, t, raw_msg, self.msg_s_store)
            self.__update_msg_store__(dst, src, t, raw_msg, self.msg_d_store)
        else:
            self.__update_msg_store__(src, dst, t, raw_msg, self.msg_s_store)
            self.__update_msg_store__(dst, src, t, raw_msg, self.msg_d_store)
            self.__update_memory__(n_id)

    def __reset_message_store__(self):
        i = self.memory.new_empty((0, ), dtype=torch.long)
        msg = self.memory.new_empty((0, self.raw_msg_dim))
        # Message store format: (src, dst, t, msg)
        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}
        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}

    def __update_memory__(self, n_id):
        memory, last_update = self.__get_updated_memory__(n_id)
        self.memory[n_id] = memory
        self.last_update[n_id] = last_update

    def __get_updated_memory__(self, n_id):
        self.__assoc__[n_id] = torch.arange(n_id.size(0), device=n_id.device)

        # Compute messages (src -> dst).
        msg_s, t_s, src_s, dst_s = self.__compute_msg__(
            n_id, self.msg_s_store, self.msg_s_module)

        # Compute messages (dst -> src).
        msg_d, t_d, src_d, dst_d = self.__compute_msg__(
            n_id, self.msg_d_store, self.msg_d_module)

        # Aggregate messages.
        idx = torch.cat([src_s, src_d], dim=0)
        msg = torch.cat([msg_s, msg_d], dim=0)
        t = torch.cat([t_s, t_d], dim=0)
        aggr = self.aggr_module(msg, self.__assoc__[idx], t, n_id.size(0))

        # Get local copy of updated memory.
        memory = self.gru(aggr, self.memory[n_id])

        # Get local copy of updated `last_update`.
        dim_size = self.last_update.size(0)
        last_update = scatter_max(t, idx, dim=0, dim_size=dim_size)[0][n_id]

        return memory, last_update

    def __update_msg_store__(self, src, dst, t, raw_msg, msg_store):
        n_id, perm = src.sort()
        n_id, count = n_id.unique_consecutive(return_counts=True)
        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):
            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])

    def __compute_msg__(self, n_id, msg_store, msg_module):
        data = [msg_store[i] for i in n_id.tolist()]
        src, dst, t, raw_msg = list(zip(*data))
        src = torch.cat(src, dim=0)
        dst = torch.cat(dst, dim=0)
        t = torch.cat(t, dim=0)
        raw_msg = torch.cat(raw_msg, dim=0)
        t_rel = t - self.last_update[src]
        t_enc = self.time_enc(t_rel.to(raw_msg.dtype))

        msg = msg_module(self.memory[src], self.memory[dst], raw_msg, t_enc)

        return msg, t, src, dst

    def train(self, mode: bool = True):
        """Sets the module in training mode."""
        if self.training and not mode:
            # Flush message store to memory in case we just entered eval mode.
            self.__update_memory__(
                torch.arange(self.num_nodes, device=self.memory.device))
            self.__reset_message_store__()
        super().train(mode)
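# A minimal usage sketch for the TGNMemory class above, following the flow of
# PyG's examples/tgn.py and assuming the IdentityMessage and LastAggregator
# helpers that ship with torch_geometric. The dimensions and event tensors
# below are illustrative, not taken from the original code.
import torch
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator

num_nodes, raw_msg_dim, memory_dim, time_dim = 100, 16, 32, 8
memory = TGNMemory(num_nodes, raw_msg_dim, memory_dim, time_dim,
                   message_module=IdentityMessage(raw_msg_dim, memory_dim,
                                                  time_dim),
                   aggregator_module=LastAggregator())

# A small batch of temporal interaction events (src, dst, t, raw_msg).
src = torch.tensor([0, 1, 2])
dst = torch.tensor([3, 4, 5])
t = torch.tensor([10, 11, 12])
raw_msg = torch.randn(3, raw_msg_dim)

n_id = torch.cat([src, dst]).unique()
mem, last_update = memory(n_id)            # current memory of involved nodes
memory.update_state(src, dst, t, raw_msg)  # ingest the new events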
def __init__(self, c_dim, h_dim):
    super(GRUAggregation, self).__init__()
    self.gru = GRUCell(c_dim, h_dim)
class PlainRNN(BaseModel):
    """
    single_loss: single_loss=True means the loss is only calculated at the
    last time step.
    """
    def __init__(self, inp_size, hid_size, out_size, rnn_type="raw_rnn",
                 single_loss=True):
        super().__init__()
        allow_rnn_types = ["raw_rnn", "lstm", "gru"]
        assert rnn_type in allow_rnn_types
        self.rnn_type = rnn_type
        self.single_loss = single_loss
        self.inp_size = inp_size
        self.hid_size = hid_size
        self.out_size = out_size
        if rnn_type == "raw_rnn":
            self.lstm = RNNCell(inp_size, hid_size)
        if rnn_type == "lstm":
            self.lstm = LSTMCell(inp_size, hid_size)
        if rnn_type == "gru":
            self.lstm = GRUCell(inp_size, hid_size)
        self.fc1 = nn.Linear(hid_size, out_size)
        self.criterion = nn.CrossEntropyLoss()

    def init_weights(self):
        self.lstm.reset_parameters()
        self.fc1.reset_parameters()

    def init_states(self, batch_size):
        if self.rnn_type == "lstm":
            self.h = torch.zeros(batch_size, self.hid_size).to(device)
            self.c = torch.zeros(batch_size, self.hid_size).to(device)
        if self.rnn_type == "raw_rnn" or self.rnn_type == "gru":
            self.h = torch.zeros(batch_size, self.hid_size).to(device)

    def forward_train(self, x):
        assert len(x) == 2
        inp_x, inp_y = x
        inp_x = inp_x.to(device)
        inp_y = inp_y.to(device)
        batch, T, _ = inp_x.shape
        self.init_states(batch)
        dm_states = []
        loss = 0
        rr = torch.zeros((batch, self.out_size)).to(device)
        if self.rnn_type == "lstm":
            for i in range(T):
                y, (self.h, self.c) = self.lstm(inp_x[:, i], (self.h, self.c))
                output = self.fc1(y)
                rr.copy_(output)
                dm_states.append(rr.cpu().detach().numpy())
                if not self.single_loss:
                    loss += self.criterion(output, inp_y.reshape(-1))
            if self.single_loss:
                loss += self.criterion(output, inp_y.reshape(-1))
        if self.rnn_type == "raw_rnn" or self.rnn_type == "gru":
            for i in range(T):
                self.h = self.lstm(inp_x[:, i], self.h)
                output = self.fc1(self.h)
                rr.copy_(output)
                dm_states.append(rr.cpu().detach().numpy())
                if not self.single_loss:
                    loss += self.criterion(output, inp_y.reshape(-1))
            if self.single_loss:
                loss += self.criterion(output, inp_y.reshape(-1))
        print("loss is ", loss.cpu().item())
        outputs = dict(
            loss=loss,
            outputs=dm_states
        )
        return outputs

    def forward_test(self, x):
        # x, new_state = self.lstm(x, state)
        # x = self.fc1(x)
        assert len(x) == 2
        inp_x, inp_y = x
        inp_x = inp_x.to(device)
        batch, T, _ = inp_x.shape
        self.init_states(batch)
        dm_states = []
        rr = torch.zeros((batch, self.out_size)).to(device)
        if self.rnn_type == "lstm":
            for i in range(T):
                y, (self.h, self.c) = self.lstm(inp_x[:, i], (self.h, self.c))
                output = self.fc1(y)
                rr.copy_(output)
                dm_states.append(rr.cpu().detach().numpy())
        if self.rnn_type == "raw_rnn" or self.rnn_type == "gru":
            for i in range(T):
                self.h = self.lstm(inp_x[:, i], self.h)
                output = self.fc1(self.h)
                rr.copy_(output)
                dm_states.append(rr.cpu().detach().numpy())
        if self.single_loss:
            dm_states = (np.array(dm_states)[-1]).reshape(1, batch, -1)
        else:
            # dm_states = (np.array(dm_states)).reshape(T, batch, -1)
            dm_states = np.array(dm_states).reshape(T * batch, -1)
            dm_states_ = np.zeros_like(dm_states)
            index = np.argmax(dm_states, axis=1)
            dm_states_[range(T * batch), index] = 1
            dm_states = dm_states_.reshape(T, batch, -1).mean(axis=0) \
                .reshape(1, batch, -1)
        inp_y = inp_y.view(-1).cpu().numpy()
        outputs = dict(
            outputs=dm_states,
            labels=inp_y
        )
        return outputs
class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        out_channels (int): Size of each output sample.
        edge_dim (int): Edge feature dimensionality.
        num_layers (int): Number of GNN layers.
        num_timesteps (int): Number of iterative refinement steps for global
            readout.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
    """
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, edge_dim: int, num_layers: int,
                 num_timesteps: int, dropout: float = 0.0):
        super().__init__()

        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)

        conv = GATEConv(hidden_channels, hidden_channels, edge_dim, dropout)
        gru = GRUCell(hidden_channels, hidden_channels)
        self.atom_convs = torch.nn.ModuleList([conv])
        self.atom_grus = torch.nn.ModuleList([gru])
        for _ in range(num_layers - 1):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout,
                           add_self_loops=False, negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))

        self.mol_conv = GATConv(hidden_channels, hidden_channels,
                                dropout=dropout, add_self_loops=False,
                                negative_slope=0.01)
        self.mol_gru = GRUCell(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin1.reset_parameters()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, edge_index, edge_attr, batch):
        """"""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)
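# A minimal usage sketch for the AttentiveFP class above, driving its
# forward(x, edge_index, edge_attr, batch) signature with a toy graph. For
# convenience it imports the packaged torch_geometric.nn.models.AttentiveFP;
# the tensor shapes below are illustrative assumptions, not values from the
# original code.
import torch
from torch_geometric.nn.models import AttentiveFP

model = AttentiveFP(in_channels=39, hidden_channels=64, out_channels=1,
                    edge_dim=10, num_layers=2, num_timesteps=2, dropout=0.1)

num_nodes, num_edges = 5, 8
x = torch.randn(num_nodes, 39)                             # atom features
edge_index = torch.randint(0, num_nodes, (2, num_edges))   # directed bonds
edge_attr = torch.randn(num_edges, 10)                     # bond features
batch = torch.zeros(num_nodes, dtype=torch.long)           # single molecule

out = model(x, edge_index, edge_attr, batch)               # shape: (1, 1)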