def forward(self,
                input,
                lengths=None,
                hidden=None,
                dump_layers=False,
                intervention=None):
        """ See :obj:`onmt.modules.EncoderBase.forward()`"""
        self._check_args(input, lengths, hidden)

        emb = self.embeddings(input)
        # s_len, batch, emb_dim = emb.size()

        emb = emb.transpose(0, 1).contiguous()
        emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
        emb_remap = self.linear(emb_reshape)
        emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
        emb_remap = shape_transform(emb_remap)
        if dump_layers:
            dumped_layers, out = self.cnn(emb_remap, True, intervention)
        else:
            out = self.cnn(emb_remap, False, intervention)

        if dump_layers:
            return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
                dumped_layers, \
                out.squeeze(3).transpose(0, 1).contiguous()
        else:
            return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
                out.squeeze(3).transpose(0, 1).contiguous()
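The remap pipeline shared by the encoder examples (transpose, flatten, linear projection, shape_transform) is mostly shape bookkeeping. Below is a minimal sketch of just that pipeline on a dummy tensor, with a plain nn.Linear standing in for self.linear and shape_transform written out the way onmt.utils.cnn_factory defines it; the tensor sizes are illustrative assumptions, not values from the snippet above.

import torch
import torch.nn as nn

def shape_transform(x):
    # (batch, len, hidden) -> (batch, hidden, len, 1), the layout the conv stack expects
    return torch.unsqueeze(x.transpose(1, 2), 3)

src_len, batch, emb_dim, hidden = 26, 4, 500, 512        # illustrative sizes
linear = nn.Linear(emb_dim, hidden)                      # stands in for self.linear

emb = torch.randn(src_len, batch, emb_dim)               # stands in for self.embeddings(input)
emb = emb.transpose(0, 1).contiguous()                   # (batch, src_len, emb_dim)
emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)    # (batch * src_len, emb_dim)
emb_remap = linear(emb_reshape)                          # (batch * src_len, hidden)
emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1) # (batch, src_len, hidden)
emb_remap = shape_transform(emb_remap)                   # (batch, hidden, src_len, 1)
print(emb_remap.squeeze(3).transpose(0, 1).shape)        # torch.Size([512, 4, 26]): (hidden, batch, src_len), the returned layout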
Example #2
    def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
        """ See :obj:`onmt.modules.RNNDecoderBase.forward()`"""
        # NOTE: memory_lengths is only here for compatibility reasons
        #       with onmt.modules.RNNDecoderBase.forward()

        if self.state["previous_input"] is not None:
            tgt = torch.cat([self.state["previous_input"], tgt], 0)

        # Initialize return variables.
        dec_outs = []
        attns = {"std": []}
        assert not self._copy, "Copy mechanism not yet tested in conv2conv"
        if self._copy:
            attns["copy"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        tgt_emb = emb.transpose(0, 1).contiguous()
        # The output of CNNEncoder.
        src_memory_bank_t = memory_bank.transpose(0, 1).contiguous()
        # The combination of output of CNNEncoder and source embeddings.
        src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous()

        # Run the forward pass of the CNNDecoder.
        emb_reshape = tgt_emb.contiguous().view(
            tgt_emb.size(0) * tgt_emb.size(1), -1)
        linear_out = self.linear(emb_reshape)
        x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
        x = shape_transform(x)

        pad = torch.zeros(x.size(0), x.size(1),
                          self.cnn_kernel_width - 1, 1)

        pad = pad.type_as(x)
        base_target_emb = x

        for conv, attention in zip(self.conv_layers, self.attn_layers):
            new_target_input = torch.cat([pad, x], 2)
            out = conv(new_target_input)
            c, attn = attention(base_target_emb, out,
                                src_memory_bank_t, src_memory_bank_c)
            x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
        output = x.squeeze(3).transpose(1, 2)

        # Process the result and update the attentions.
        dec_outs = output.transpose(0, 1).contiguous()
        if self.state["previous_input"] is not None:
            dec_outs = dec_outs[self.state["previous_input"].size(0):]
            attn = attn[:, self.state["previous_input"].size(0):].squeeze()
            attn = torch.stack([attn])
        attns["std"] = attn
        if self._copy:
            attns["copy"] = attn

        # Update the state.
        self.update_state(tgt)
        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns
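The zero pad of width cnn_kernel_width - 1 concatenated in front of x makes each convolution causal: output position t only sees target positions up to t, so the decoder never conditions on future tokens. A minimal sketch of that effect, using a plain unpadded nn.Conv2d as a stand-in for the gated conv layers; the kernel width and tensor sizes are illustrative assumptions.

import torch
import torch.nn as nn

kernel_width, hidden, batch, tgt_len = 3, 8, 2, 5
conv = nn.Conv2d(hidden, hidden, kernel_size=(kernel_width, 1))  # no padding of its own

x = torch.randn(batch, hidden, tgt_len, 1)               # same layout as shape_transform output
pad = torch.zeros(x.size(0), x.size(1), kernel_width - 1, 1).type_as(x)
out = conv(torch.cat([pad, x], 2))
print(out.shape)                                         # torch.Size([2, 8, 5, 1]): length preserved

# Perturbing only the last target position leaves all earlier outputs untouched,
# i.e. the left-padded convolution is causal.
x2 = x.clone()
x2[:, :, -1, :] += 1.0
out2 = conv(torch.cat([pad, x2], 2))
print(torch.allclose(out[:, :, :-1], out2[:, :, :-1]))   # True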
Example #3
    def forward(self, tgt, memory_bank, step=None, **kwargs):
        """ See :obj:`onmt.modules.RNNDecoderBase.forward()`"""

        if self.state["previous_input"] is not None:
            tgt = torch.cat([self.state["previous_input"], tgt], 0)

        dec_outs = []
        attns = {"std": []}
        if self.copy_attn is not None:
            attns["copy"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        tgt_emb = emb.transpose(0, 1).contiguous()
        # The output of CNNEncoder.
        src_memory_bank_t = memory_bank.transpose(0, 1).contiguous()
        # The combination of output of CNNEncoder and source embeddings.
        src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous()

        emb_reshape = tgt_emb.contiguous().view(
            tgt_emb.size(0) * tgt_emb.size(1), -1)
        linear_out = self.linear(emb_reshape)
        x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
        x = shape_transform(x)

        pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)

        pad = pad.type_as(x)
        base_target_emb = x

        for conv, attention in zip(self.conv_layers, self.attn_layers):
            new_target_input = torch.cat([pad, x], 2)
            out = conv(new_target_input)
            c, attn = attention(base_target_emb, out,
                                src_memory_bank_t, src_memory_bank_c)
            x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
        output = x.squeeze(3).transpose(1, 2)

        # Process the result and update the attentions.
        dec_outs = output.transpose(0, 1).contiguous()
        if self.state["previous_input"] is not None:
            dec_outs = dec_outs[self.state["previous_input"].size(0):]
            attn = attn[:, self.state["previous_input"].size(0):].squeeze()
            attn = torch.stack([attn])
        attns["std"] = attn
        if self.copy_attn is not None:
            attns["copy"] = attn

        # Update the state.
        self.state["previous_input"] = tgt
        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns
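The residual update x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT follows the ConvS2S convention of scaling each sum by sqrt(0.5), so that adding two roughly unit-variance signals does not grow the activation variance layer after layer. A quick numeric check, with SCALE_WEIGHT = 0.5 ** 0.5 as in onmt.utils.cnn_factory:

import torch

SCALE_WEIGHT = 0.5 ** 0.5   # sqrt(0.5)

a = torch.randn(100_000)    # two independent, roughly unit-variance signals
b = torch.randn(100_000)

print(a.var().item())                          # ~1.0
print((a + b).var().item())                    # ~2.0: the sum doubles the variance
print(((a + b) * SCALE_WEIGHT).var().item())   # ~1.0: scaling by sqrt(0.5) restores it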
Example #4
    def forward(self, input, lengths=None, hidden=None):
        """ See :obj:`onmt.modules.EncoderBase.forward()`"""
        self._check_args(input, lengths, hidden)

        emb = self.embeddings(input)

        emb = emb.transpose(0, 1).contiguous()
        emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
        emb_remap = self.linear(emb_reshape)
        emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
        emb_remap = shape_transform(emb_remap)
        out = self.cnn(emb_remap)

        return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
            out.squeeze(3).transpose(0, 1).contiguous(), lengths
Example #5
    def forward(self, input, lengths=None, hidden=None):
        """See :class:`onmt.modules.EncoderBase.forward()`"""
        self._check_args(input, lengths, hidden)

        emb = self.embeddings(input)
        # s_len, batch, emb_dim = emb.size()

        emb = emb.transpose(0, 1).contiguous()
        emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
        emb_remap = self.linear(emb_reshape)
        emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
        emb_remap = shape_transform(emb_remap)
        out = self.cnn(emb_remap)

        return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
            out.squeeze(3).transpose(0, 1).contiguous(), lengths
Example #6
    def cnnforward(self, input, lengths=None, hidden=None):
        """See :class:`onmt.modules.EncoderBase.forward()`"""
        # import ipdb; ipdb.set_trace()  # debugging breakpoint
        self._check_args(input, lengths, hidden)

        emb = self.embeddings(input)  #[src_len, bsz, emb_dim]
        emb = emb.transpose(0, 1).contiguous()  #[bsz, src_len, emb_dim]
        emb_reshape = emb.view(emb.size(0) * emb.size(1),
                               -1)  #[(bsz*src_len), emb_dim]
        emb_remap = self.linear(emb_reshape)  #[(bsz*src_len), emb_dim]
        emb_remap = emb_remap.view(emb.size(0), emb.size(1),
                                   -1)  #[bsz, src_len, emb_dim]
        emb_remap = shape_transform(emb_remap)  #[bsz, emb_dim, src_len, 1]
        out = self.cnn(emb_remap)  #[bsz, emb_dim, src_len, 1]

        # shapes: [emb_dim, bsz, src_len], [emb_dim, bsz, src_len], [bsz]
        return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
            out.squeeze(3).transpose(0, 1).contiguous(), lengths
Example #7
    def forward(self, input, lengths=None, hidden=None):
        """ See :obj:`onmt.modules.EncoderBase.forward()`"""
        self._check_args(input, lengths, hidden)
        if self.embeddings is not None:
            emb = self.embeddings(input)
        else:
            emb = input.transpose(0, 1)  #torch.Size([26, 64, 500])
        # s_len, batch, emb_dim = emb.size()

        emb = emb.transpose(0, 1).contiguous()
        emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
        emb_remap = self.linear(emb_reshape)
        emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
        emb_remap = shape_transform(emb_remap)
        out = self.cnn(emb_remap)  #torch.Size([64, 512, 26, 1])

        return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
            out.squeeze(3).transpose(0, 1).contiguous(), lengths
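This last variant also accepts pre-embedded features: when self.embeddings is None, the input is expected as a (batch, src_len, feat_dim) float tensor and is transposed into the same (src_len, batch, feat_dim) layout that the embedding branch produces, so the rest of the pipeline is shared. A small shape check with dummy tensors; the sizes are taken from the inline comments above.

import torch

batch, src_len, feat_dim = 64, 26, 500

# Token-id path: self.embeddings(input) yields (src_len, batch, feat_dim).
emb_from_tokens = torch.randn(src_len, batch, feat_dim)

# Pre-embedded path: features arrive as (batch, src_len, feat_dim) and are
# transposed into the same layout before the shared remap / conv pipeline.
features = torch.randn(batch, src_len, feat_dim)
emb_from_features = features.transpose(0, 1)

assert emb_from_tokens.shape == emb_from_features.shape   # torch.Size([26, 64, 500])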