def __init__(self, input_size, hidden_size, num_layers, bidirectional,
             rnn_type="GRU", dropout=0., stateful=False, batch_first=True):
    super().__init__(stateful=stateful)
    rnn_cls = Seq2SeqEncoder.by_name(rnn_type.lower())
    # Layer 0 consumes the raw input; every later layer consumes the previous
    # layer's output, which has size hidden_size * 2 when bidirectional.
    self.rnn_list = torch.nn.ModuleList([
        rnn_cls(input_size=input_size,
                hidden_size=hidden_size,
                bidirectional=bidirectional,
                dropout=dropout)])
    for _ in range(num_layers - 1):
        self.rnn_list.append(
            rnn_cls(input_size=hidden_size * 2 if bidirectional else hidden_size,
                    hidden_size=hidden_size,
                    bidirectional=bidirectional,
                    dropout=dropout))
    self.dropout = RNNDropout(dropout, batch_first=batch_first)

def __init__(self, input_size, hidden_size, num_layers, bidirectional=True,
             rnn_type="GRU", stateful=False, batch_first=True):
    super().__init__(stateful=stateful)
    self.input_dim = input_size
    # The per-layer outputs are concatenated, hence the num_layers factor.
    self.output_dim = num_layers * (hidden_size * 2 if bidirectional else hidden_size)
    self.bidirectional = bidirectional
    rnn_cls = Seq2SeqEncoder.by_name(rnn_type.lower())
    self.rnn_list = torch.nn.ModuleList([
        rnn_cls(input_size=input_size,
                hidden_size=hidden_size,
                bidirectional=bidirectional)])
    for _ in range(num_layers - 1):
        self.rnn_list.append(
            rnn_cls(input_size=hidden_size * 2 if bidirectional else hidden_size,
                    hidden_size=hidden_size,
                    bidirectional=bidirectional))

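# A minimal, self-contained sketch of the layer-stacking pattern the two
# constructors above set up (their forward methods are not shown here, so the
# driving loop below is an assumption consistent with output_dim). Assumes
# AllenNLP's Seq2SeqEncoder registry and its (inputs, mask) calling convention.
import torch
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder

rnn_cls = Seq2SeqEncoder.by_name('gru')
layers = torch.nn.ModuleList(
    [rnn_cls(input_size=8, hidden_size=16, bidirectional=True)]
    + [rnn_cls(input_size=32, hidden_size=16, bidirectional=True)
       for _ in range(2)])

x = torch.randn(4, 10, 8)                   # (batch, time, input_size)
mask = torch.ones(4, 10, dtype=torch.bool)  # all positions are real tokens
outputs = []
for layer in layers:
    x = layer(x, mask)                      # (batch, time, hidden_size * 2)
    outputs.append(x)
# Concatenating per-layer outputs matches output_dim = num_layers * hidden_size * 2.
stacked = torch.cat(outputs, dim=-1)        # (batch, time, 3 * 32)
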
def __init__(self, memory_size, input_size, hidden_size, attention_size,
             bidirectional, dropout, attention_factory=StaticDotAttention,
             rnn_type="GRU", batch_first=True):
    super().__init__()
    rnn_fn = Seq2SeqEncoder.by_name(rnn_type.lower())
    self.attention = attention_factory(memory_size, input_size, attention_size,
                                       dropout=dropout, batch_first=batch_first)
    self.gate = nn.Sequential(
        Gate(input_size + memory_size, dropout=dropout),
        RNNDropout(dropout, batch_first=batch_first))
    self.encoder = rnn_fn(input_size=memory_size + input_size,
                          hidden_size=hidden_size,
                          bidirectional=bidirectional,
                          batch_first=batch_first)

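# A minimal, self-contained sketch of the composition the constructor above
# wires up (attention -> gate -> RNN), in the spirit of R-Net's gated
# attention. All module names below are stand-ins, not the project's own
# Gate/StaticDotAttention classes, and the forward logic is an assumption.
import torch
import torch.nn as nn

class TinyGatedAttention(nn.Module):
    def __init__(self, memory_size, input_size, hidden_size):
        super().__init__()
        self.proj = nn.Linear(input_size, memory_size, bias=False)  # dot-attention projection
        self.gate = nn.Linear(input_size + memory_size, input_size + memory_size)
        self.rnn = nn.GRU(input_size + memory_size, hidden_size,
                          batch_first=True, bidirectional=True)

    def forward(self, inputs, memory):
        # Scaled dot attention of the inputs over the memory.
        scores = self.proj(inputs) @ memory.transpose(1, 2) / memory.size(-1) ** 0.5
        attended = torch.softmax(scores, dim=-1) @ memory
        combined = torch.cat([inputs, attended], dim=-1)
        gated = torch.sigmoid(self.gate(combined)) * combined  # element-wise input gate
        outputs, _ = self.rnn(gated)
        return outputs

x = torch.randn(2, 7, 32)      # (batch, time, input_size)
mem = torch.randn(2, 11, 48)   # (batch, memory_len, memory_size)
out = TinyGatedAttention(48, 32, 64)(x, mem)
print(out.shape)               # torch.Size([2, 7, 128])
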
def __init__(self, input_dim: int = None, context_vector_dim: int = None) -> None:
    super().__init__()
    context_vector_dim = context_vector_dim or input_dim
    # self.alpha = torch.nn.Parameter(torch.randn(1))
    self._mlp = torch.nn.Linear(input_dim, context_vector_dim, bias=True)
    self._context_dot_product = torch.nn.Linear(context_vector_dim, 1, bias=False)
    self.vec_dim = self._mlp.weight.size(1)
    self._encoder = Seq2SeqEncoder.by_name('gru')(input_size=input_dim,
                                                  hidden_size=input_dim // 2,
                                                  bidirectional=True)

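# A sketch of how the pieces above typically combine (the forward pass is not
# shown here, so the combination is an assumption): context-vector attention
# in the style of Yang et al.'s hierarchical attention networks. Note the
# bidirectional GRU's output size is 2 * (input_dim // 2) = input_dim, which
# is why the stand-in tensor below has the same width as the input.
import torch

batch, time, input_dim = 2, 5, 16
encoded = torch.randn(batch, time, input_dim)            # stand-in for the GRU output
mlp = torch.nn.Linear(input_dim, input_dim, bias=True)   # plays the role of _mlp
context_dot = torch.nn.Linear(input_dim, 1, bias=False)  # plays the role of _context_dot_product

u = torch.tanh(mlp(encoded))               # hidden representation per timestep
scores = context_dot(u).squeeze(-1)        # similarity to the learned context vector
weights = torch.softmax(scores, dim=-1)    # (batch, time) attention distribution
summary = torch.bmm(weights.unsqueeze(1), encoded).squeeze(1)  # (batch, input_dim)
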
def get_rnns(rnn_type: str, input_size: int, hidden_size: int,
             num_layers: int, bidirectional: bool):
    """
    Creates an AllenNLP ``Seq2SeqEncoder`` and an equivalent raw ``RNN``,
    copying the encoder's weights into the ``RNN`` so the two modules
    compute identical outputs.
    """
    assert num_layers in [1, 2]
    assert rnn_type in ['gru', 'lstm']
    seq2seq_encoder = Seq2SeqEncoder.by_name(rnn_type)(
        input_size=input_size, hidden_size=hidden_size,
        num_layers=num_layers, bidirectional=bidirectional)
    if rnn_type == 'gru':
        rnn = GRU(input_size, hidden_size, num_layers, bidirectional)
    else:
        rnn = LSTM(input_size, hidden_size, num_layers, bidirectional)
    # Copy every weight/bias tensor, layer by layer and direction by
    # direction. no_grad() is required for in-place writes to parameters.
    with torch.no_grad():
        for layer in range(num_layers):
            suffixes = ['', '_reverse'] if bidirectional else ['']
            for suffix in suffixes:
                for name in ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh'):
                    param = f'{name}_l{layer}{suffix}'
                    getattr(rnn.rnn, param)[:] = getattr(seq2seq_encoder._module, param)[:]
    return seq2seq_encoder, rnn
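
# A quick sanity check one might run on get_rnns (a sketch; the custom
# GRU/LSTM wrapper classes are not shown here, so `rnn.rnn` being a
# batch-first torch.nn.GRU is an assumption). With a full mask, the AllenNLP
# wrapper and the raw GRU should agree after the weight copy.
import torch

seq2seq_encoder, rnn = get_rnns('gru', input_size=8, hidden_size=16,
                                num_layers=2, bidirectional=True)
x = torch.randn(4, 10, 8)                    # (batch, time, features)
mask = torch.ones(4, 10, dtype=torch.bool)   # no padding
with torch.no_grad():
    out_wrapped = seq2seq_encoder(x, mask)   # AllenNLP wrapper forward
    out_raw, _ = rnn.rnn(x)                  # underlying torch.nn.GRU
print(torch.allclose(out_wrapped, out_raw, atol=1e-6))  # expect True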