import torch
import torch.nn as nn

# Embeddings, CGRUCell and my_init are project-local modules and are assumed
# to be imported at the top of this file.


class Decoder(nn.Module):

    def __init__(self, n_words, input_size, hidden_size,
                 bridge_type="mlp", dropout_rate=0.0):
        super(Decoder, self).__init__()

        self.bridge_type = bridge_type
        self.hidden_size = hidden_size
        self.context_size = hidden_size * 2

        self.embedding = Embeddings(num_embeddings=n_words,
                                    embedding_dim=input_size,
                                    dropout=0.0,
                                    add_position_embedding=False)

        self.cgru_cell = CGRUCell(input_size=input_size, hidden_size=hidden_size)

        self.linear_input = nn.Linear(in_features=input_size, out_features=input_size)
        self.linear_hidden = nn.Linear(in_features=hidden_size, out_features=input_size)
        self.linear_ctx = nn.Linear(in_features=hidden_size * 2, out_features=input_size)

        self.dropout = nn.Dropout(dropout_rate)

        self._reset_parameters()

        self._build_bridge()

    def _reset_parameters(self):
        my_init.default_init(self.linear_input.weight)
        my_init.default_init(self.linear_hidden.weight)
        my_init.default_init(self.linear_ctx.weight)

    def _build_bridge(self):
        if self.bridge_type == "mlp":
            self.linear_bridge = nn.Linear(in_features=self.context_size,
                                           out_features=self.hidden_size)
            my_init.default_init(self.linear_bridge.weight)
        elif self.bridge_type == "zero":
            pass
        else:
            raise ValueError("Unknown bridge type {0}".format(self.bridge_type))

    def init_decoder(self, context, mask):
        # Generate the initial decoder hidden state.
        if self.bridge_type == "mlp":
            # Masked mean over the source positions, then a tanh projection.
            no_pad_mask = 1.0 - mask.float()
            ctx_mean = (context * no_pad_mask.unsqueeze(2)).sum(1) \
                / no_pad_mask.unsqueeze(2).sum(1)
            dec_init = torch.tanh(self.linear_bridge(ctx_mean))

        elif self.bridge_type == "zero":
            batch_size = context.size(0)
            dec_init = context.new(batch_size, self.hidden_size).zero_()

        else:
            raise ValueError("Unknown bridge type {0}".format(self.bridge_type))

        # Pre-compute the attention cache over the source context.
        dec_cache = self.cgru_cell.compute_cache(context)

        return dec_init, dec_cache

    def forward(self, y, context, context_mask, hidden, one_step=False, cache=None):
        emb = self.embedding(y)  # [batch_size, seq_len, dim]

        if one_step:
            (out, attn), hidden = self.cgru_cell(emb, hidden, context, context_mask, cache)
        else:
            # emb: [batch_size, seq_len, dim]
            out = []
            attn = []

            # Unroll the CGRU cell one target time step at a time.
            for emb_t in torch.split(emb, split_size_or_sections=1, dim=1):
                (out_t, attn_t), hidden = self.cgru_cell(emb_t.squeeze(1), hidden,
                                                         context, context_mask, cache)
                out += [out_t]
                attn += [attn_t]

            out = torch.stack(out).transpose(1, 0).contiguous()
            attn = torch.stack(attn).transpose(1, 0).contiguous()

        # Deep-output layer: combine embedding, decoder state and attention context.
        logits = self.linear_input(emb) + self.linear_hidden(out) + self.linear_ctx(attn)

        logits = torch.tanh(logits)

        logits = self.dropout(logits)  # [batch_size, seq_len, dim]

        return logits, hidden
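
# --- Minimal usage sketch (illustrative only, not part of the original file) ---
# Assumptions: the project-local Embeddings, CGRUCell and my_init modules are
# importable; the encoder context has shape [batch_size, src_len, 2 * hidden_size]
# (e.g. a bidirectional recurrent encoder); the mask marks padded source positions
# with 1/True, as implied by init_decoder(); and all sizes below are made up.
# The returned "logits" are deep-output features of size input_size that a
# separate generator layer would project onto the vocabulary.
if __name__ == "__main__":
    batch_size, src_len, tgt_len = 2, 7, 5
    n_words, input_size, hidden_size = 100, 16, 32

    decoder = Decoder(n_words=n_words, input_size=input_size,
                      hidden_size=hidden_size, bridge_type="mlp")

    # Fake encoder output and an all-False padding mask (no padded positions).
    context = torch.randn(batch_size, src_len, hidden_size * 2)
    context_mask = context.new_zeros(batch_size, src_len).eq(1)

    # Target-side input tokens (teacher forcing over the whole sequence).
    y = torch.randint(low=0, high=n_words, size=(batch_size, tgt_len))

    # Bridge the encoder states into the initial decoder state and attention cache.
    dec_init, dec_cache = decoder.init_decoder(context, context_mask)

    logits, hidden = decoder(y, context, context_mask, dec_init,
                             one_step=False, cache=dec_cache)
    print(logits.shape)  # expected: [batch_size, tgt_len, input_size]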