def forward(self, input, lengths=None, hidden=None, dump_layers=False,
            intervention=None):
    """ See :obj:`onmt.modules.EncoderBase.forward()`"""
    self._check_args(input, lengths, hidden)

    emb = self.embeddings(input)
    # s_len, batch, emb_dim = emb.size()

    emb = emb.transpose(0, 1).contiguous()
    emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
    emb_remap = self.linear(emb_reshape)
    emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
    emb_remap = shape_transform(emb_remap)

    if dump_layers:
        dumped_layers, out = self.cnn(emb_remap, True, intervention)
    else:
        out = self.cnn(emb_remap, False, intervention)

    if dump_layers:
        return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
            dumped_layers, \
            out.squeeze(3).transpose(0, 1).contiguous()
    else:
        return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
            out.squeeze(3).transpose(0, 1).contiguous()
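# The helpers used by the variants in this section (shape_transform, SCALE_WEIGHT)
# are not defined here. The following is a minimal sketch of what
# onmt.utils.cnn_factory typically provides; verify against your own checkout.
import torch

SCALE_WEIGHT = 0.5 ** 0.5  # rescales residual sums to keep their variance roughly constant


def shape_transform(x):
    """Reshape (batch, len, hidden) -> (batch, hidden, len, 1) to fit Conv2d input."""
    return torch.unsqueeze(torch.transpose(x, 1, 2), 3)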
def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
    """ See :obj:`onmt.modules.RNNDecoderBase.forward()`"""
    # NOTE: memory_lengths is only here for compatibility reasons
    # with onmt.modules.RNNDecoderBase.forward()
    if self.state["previous_input"] is not None:
        tgt = torch.cat([self.state["previous_input"], tgt], 0)

    # Initialize return variables.
    dec_outs = []
    attns = {"std": []}
    assert not self._copy, "Copy mechanism not yet tested in conv2conv"
    if self._copy:
        attns["copy"] = []

    emb = self.embeddings(tgt)
    assert emb.dim() == 3  # len x batch x embedding_dim

    tgt_emb = emb.transpose(0, 1).contiguous()
    # The output of CNNEncoder.
    src_memory_bank_t = memory_bank.transpose(0, 1).contiguous()
    # The combination of output of CNNEncoder and source embeddings.
    src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous()

    # Run the forward pass of the CNNDecoder.
    emb_reshape = tgt_emb.contiguous().view(
        tgt_emb.size(0) * tgt_emb.size(1), -1)
    linear_out = self.linear(emb_reshape)
    x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
    x = shape_transform(x)

    pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)
    pad = pad.type_as(x)
    base_target_emb = x

    for conv, attention in zip(self.conv_layers, self.attn_layers):
        new_target_input = torch.cat([pad, x], 2)
        out = conv(new_target_input)
        c, attn = attention(base_target_emb, out,
                            src_memory_bank_t, src_memory_bank_c)
        x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
    output = x.squeeze(3).transpose(1, 2)

    # Process the result and update the attentions.
    dec_outs = output.transpose(0, 1).contiguous()
    if self.state["previous_input"] is not None:
        dec_outs = dec_outs[self.state["previous_input"].size(0):]
        attn = attn[:, self.state["previous_input"].size(0):].squeeze()
        attn = torch.stack([attn])
    attns["std"] = attn
    if self._copy:
        attns["copy"] = attn

    # Update the state.
    self.update_state(tgt)
    # TODO change the way attns is returned dict => list or tuple (onnx)
    return dec_outs, attns
def forward(self, tgt, memory_bank, step=None, **kwargs):
    """ See :obj:`onmt.modules.RNNDecoderBase.forward()`"""
    if self.state["previous_input"] is not None:
        tgt = torch.cat([self.state["previous_input"], tgt], 0)

    dec_outs = []
    attns = {"std": []}
    if self.copy_attn is not None:
        attns["copy"] = []

    emb = self.embeddings(tgt)
    assert emb.dim() == 3  # len x batch x embedding_dim

    tgt_emb = emb.transpose(0, 1).contiguous()
    # The output of CNNEncoder.
    src_memory_bank_t = memory_bank.transpose(0, 1).contiguous()
    # The combination of output of CNNEncoder and source embeddings.
    src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous()

    emb_reshape = tgt_emb.contiguous().view(
        tgt_emb.size(0) * tgt_emb.size(1), -1)
    linear_out = self.linear(emb_reshape)
    x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
    x = shape_transform(x)

    pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)
    pad = pad.type_as(x)
    base_target_emb = x

    for conv, attention in zip(self.conv_layers, self.attn_layers):
        new_target_input = torch.cat([pad, x], 2)
        out = conv(new_target_input)
        c, attn = attention(base_target_emb, out,
                            src_memory_bank_t, src_memory_bank_c)
        x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
    output = x.squeeze(3).transpose(1, 2)

    # Process the result and update the attentions.
    dec_outs = output.transpose(0, 1).contiguous()
    if self.state["previous_input"] is not None:
        dec_outs = dec_outs[self.state["previous_input"].size(0):]
        attn = attn[:, self.state["previous_input"].size(0):].squeeze()
        attn = torch.stack([attn])
    attns["std"] = attn
    if self.copy_attn is not None:
        attns["copy"] = attn

    # Update the state.
    self.state["previous_input"] = tgt
    # TODO change the way attns is returned dict => list or tuple (onnx)
    return dec_outs, attns
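# Hypothetical end-to-end call for the encoder/decoder variants above, assuming
# the usual OpenNMT-py NMTModel wiring in which the encoder output is fed to
# decoder.init_state(src, memory_bank, enc_hidden) before the first decode step.
# `encoder`, `decoder`, `src`, `tgt`, and `lengths` are illustrative names, not
# objects defined in this section.
enc_embed, memory_bank, lengths = encoder(src, lengths)
decoder.init_state(src, memory_bank, enc_embed)  # seeds self.state["src"] used above
dec_outs, attns = decoder(tgt, memory_bank, memory_lengths=lengths)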
def forward(self, input, lengths=None, hidden=None):
    """ See :obj:`onmt.modules.EncoderBase.forward()`"""
    self._check_args(input, lengths, hidden)

    emb = self.embeddings(input)

    emb = emb.transpose(0, 1).contiguous()
    emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
    emb_remap = self.linear(emb_reshape)
    emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
    emb_remap = shape_transform(emb_remap)
    out = self.cnn(emb_remap)

    return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
        out.squeeze(3).transpose(0, 1).contiguous(), lengths
def forward(self, input, lengths=None, hidden=None):
    """See :class:`onmt.modules.EncoderBase.forward()`"""
    self._check_args(input, lengths, hidden)

    emb = self.embeddings(input)
    # s_len, batch, emb_dim = emb.size()

    emb = emb.transpose(0, 1).contiguous()
    emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
    emb_remap = self.linear(emb_reshape)
    emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
    emb_remap = shape_transform(emb_remap)
    out = self.cnn(emb_remap)

    return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
        out.squeeze(3).transpose(0, 1).contiguous(), lengths
def cnnforward(self, input, lengths=None, hidden=None):
    """See :class:`onmt.modules.EncoderBase.forward()`"""
    self._check_args(input, lengths, hidden)

    emb = self.embeddings(input)                               # [src_len, bsz, emb_dim]
    emb = emb.transpose(0, 1).contiguous()                     # [bsz, src_len, emb_dim]
    emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)      # [(bsz*src_len), emb_dim]
    emb_remap = self.linear(emb_reshape)                       # [(bsz*src_len), emb_dim]
    emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)   # [bsz, src_len, emb_dim]
    emb_remap = shape_transform(emb_remap)                     # [bsz, emb_dim, src_len, 1]
    out = self.cnn(emb_remap)                                  # [bsz, emb_dim, src_len, 1]

    # Returns [emb_dim, bsz, src_len], [emb_dim, bsz, src_len], [bsz]
    return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
        out.squeeze(3).transpose(0, 1).contiguous(), lengths
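# A quick, self-contained sanity check of the shape annotations above; sizes are
# arbitrary and shape_transform is inlined under its usual OpenNMT-py definition.
import torch

bsz, src_len, emb_dim = 2, 5, 8
emb = torch.randn(bsz, src_len, emb_dim)            # [bsz, src_len, emb_dim]
x = torch.unsqueeze(torch.transpose(emb, 1, 2), 3)  # [bsz, emb_dim, src_len, 1]
assert x.shape == (bsz, emb_dim, src_len, 1)
assert x.squeeze(3).transpose(0, 1).shape == (emb_dim, bsz, src_len)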
def forward(self, input, lengths=None, hidden=None):
    """ See :obj:`onmt.modules.EncoderBase.forward()`"""
    self._check_args(input, lengths, hidden)

    if self.embeddings is not None:
        emb = self.embeddings(input)
    else:
        emb = input.transpose(0, 1)  # torch.Size([26, 64, 500])
    # s_len, batch, emb_dim = emb.size()

    emb = emb.transpose(0, 1).contiguous()
    emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
    emb_remap = self.linear(emb_reshape)
    emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
    emb_remap = shape_transform(emb_remap)
    out = self.cnn(emb_remap)  # torch.Size([64, 512, 26, 1])

    return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
        out.squeeze(3).transpose(0, 1).contiguous(), lengths