Example #1
    def process_embedding(self, input, input_lang=None):

        # if self.switchout == 0:
        #     input_ = input
        # if self.switchout > 0 and self.training:
        #     vocab_size = self.word_lut.weight.size(0)
        #     input_ = switchout(input, vocab_size, self.switchout)
        # else:
        input_ = input

        emb = embedded_dropout(
            self.word_lut,
            input_,
            dropout=self.word_dropout if self.training else 0)
        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)

        if self.use_language_embedding:
            lang_emb = self.language_embeddings(input_lang)  # B x H or 1 x H
            if self.language_embedding_type == 'sum':
                emb = emb + lang_emb
            elif self.language_embedding_type == 'concat':
                # replace the bos embedding with the language
                bos_emb = lang_emb.expand_as(emb[:, 0, :])
                emb[:, 0, :] = bos_emb

                lang_emb = lang_emb.unsqueeze(1).expand_as(emb)
                concat_emb = torch.cat([emb, lang_emb], dim=-1)
                emb = torch.relu(self.projector(concat_emb))
            else:
                raise NotImplementedError
        return emb
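
Most of the examples on this page call embedded_dropout on the embedding table before adding positional or language embeddings. For reference, below is a minimal sketch of the usual AWD-LSTM-style implementation of that helper (word-level dropout that zeroes whole rows of the embedding matrix); the actual function imported by these snippets may differ in details.

import torch
import torch.nn.functional as F

def embedded_dropout(embed, words, dropout=0.1, scale=None):
    # Drop entire embedding rows (whole word types) with probability `dropout`,
    # then rescale the surviving rows by 1 / (1 - dropout).
    if dropout:
        mask = embed.weight.new_empty((embed.weight.size(0), 1))
        mask = mask.bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
        masked_weight = mask * embed.weight
    else:
        masked_weight = embed.weight
    if scale is not None:
        masked_weight = masked_weight * scale.expand_as(masked_weight)
    return F.embedding(words, masked_weight, embed.padding_idx, embed.max_norm,
                       embed.norm_type, embed.scale_grad_by_freq, embed.sparse)

With dropout=0 this reduces to a plain embedding lookup, which is why every call site passes dropout=self.word_dropout if self.training else 0.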
Example #2
    def forward(self, input, **kwargs):
        """
        Inputs Shapes:
            input: (Variable)  len_tgt x batch_size
        Outputs Shapes:
            out: len_tgt x batch_size x  d_model
        """

        emb = embedded_dropout(
            self.word_lut,
            input,
            dropout=self.word_dropout if self.training else 0)

        emb = self.preprocess_layer(emb)

        if self.h is None:
            lstm_mem = None
        else:
            lstm_mem = (self.h.detach(), self.c.detach())

        output, (h, c) = self.rnn(emb, lstm_mem)

        output = self.postprocess_layer(output)

        output_dict = defaultdict(lambda: None)
        output_dict['hidden'] = output
        output_dict['lstm_mem'] = (h, c)

        self.h = h
        self.c = c

        return output_dict
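
Example #2 carries the LSTM state across calls and detaches it each time, which is the standard truncated back-propagation-through-time pattern. A tiny self-contained illustration of the same idea (not this decoder's API):

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=8, hidden_size=8)
state = None
for _ in range(3):                                  # three consecutive chunks
    chunk = torch.randn(5, 2, 8)                    # len x batch x features
    out, (h, c) = rnn(chunk, state)
    state = (h.detach(), c.detach())                # gradients stop at the chunk boundary
    out.sum().backward()                            # backprop stays within this chunk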
Example #3
    def forward(self, input):
        """
        Inputs Shapes: 
            input: batch_size x len_src (want to transpose)
        
        Outputs Shapes:
            out: batch_size x len_src x d_model
            mask_src 
            
        """
        """ Embedding: batch_size x len_src x d_model """
        emb = embedded_dropout(
            self.word_lut,
            input,
            dropout=self.word_dropout if self.training else 0)
        """ Scale the emb by sqrt(d_model) """

        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)
        if isinstance(emb, tuple):
            emb = emb[0]
        emb = self.preprocess_layer(emb)

        mask_src = input.data.eq(onmt.constants.PAD).unsqueeze(
            1)  # batch_size x 1 x len_src for broadcasting

        pad_mask = torch.autograd.Variable(input.data.ne(
            onmt.constants.PAD))  # batch_size x len_src
        #~ pad_mask = None

        context = emb.contiguous()

        memory_bank = None

        for i, layer in enumerate(self.layer_modules):
            if len(self.layer_modules
                   ) - i <= onmt.constants.checkpointing and self.training:
                context, memory_bank = checkpoint(custom_layer(layer), context,
                                                  memory_bank, mask_src,
                                                  pad_mask)

                #~ print(type(context))
            else:
                context, memory_bank = layer(
                    context, memory_bank, mask_src,
                    pad_mask)  # batch_size x len_src x d_model

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        context = self.postprocess_layer(context)

        # make a huge memory bank on the encoder side
        memory_bank = torch.cat([memory_bank, context.unsqueeze(0)], dim=0)

        return memory_bank, mask_src
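
Example #3 wraps its last few layers in gradient checkpointing via checkpoint(custom_layer(layer), ...). The helper is not shown on this page; a plausible sketch (hypothetical, inferred from the call site) is just a closure that turns the module into a plain function of tensors, which is what torch.utils.checkpoint expects:

from torch.utils.checkpoint import checkpoint

def custom_layer(module):
    # checkpoint() re-runs a *function* of tensors in the backward pass,
    # so wrap the nn.Module call in a closure.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

# as used in the loop above (sketch):
# context, memory_bank = checkpoint(custom_layer(layer), context, memory_bank, mask_src, pad_mask)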
Example #4
    def encode(self, input, decoder_state, input_pos=None, input_lang=None):

        buffers = decoder_state.attention_buffers
        src_lang = input_lang
        input = input.transpose(0, 1)
        # Embedding stage (and scale the embedding)
        src_emb = embedded_dropout(self.src_embedding, input, dropout=self.word_dropout if self.training else 0) \
                  * math.sqrt(self.model_size)

        if self.use_language_embedding:
            if self.language_embedding_type in ["sum", "all_sum"]:
                src_lang_emb = self.language_embeddings(src_lang)
                src_emb += src_lang_emb

        emb = src_emb
        src_len = input.size(0)
        bsz = input.size(1)
        mask_src_src = input.eq(onmt.constants.PAD).expand(src_len, src_len, bsz)

        buffer = buffers[0] if 0 in buffers else None
        if buffer is not None:
            mem_len = buffer['k'].size(0)
        else:
            mem_len = 0

        if mem_len > 0:
            # print(mask_src_src.size())
            past_mask = input.new_zeros(src_len, mem_len).bool().unsqueeze(-1).expand(src_len, mem_len, bsz)
            mask_src_src = torch.cat([past_mask, mask_src_src], dim=1)

        mask_src = mask_src_src
        attn_mask = mask_src.bool()  # L x L x batch_size

        output = emb

        klen = src_len + mem_len
        pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device, dtype=emb.dtype)

        pos_emb = self.positional_encoder(pos)

        # FORWARD PASS
        coverage = None
        for i, layer in enumerate(self.layer_modules):
            # context and context_mask are None
            buffer = buffers[i] if i in buffers else None
            # if i == 0 and buffer is not None:
            #     key = next(iter(buffer))
            #     print(buffer[key].size())
            # output, coverage, buffer = layer.step(output, None, attn_mask, None, buffer)
            output, coverage, buffer = layer(output, None, pos_emb, attn_mask, None,
                                             incremental=True, incremental_cache=buffer)
            decoder_state.update_attention_buffer(buffer, i)

        # Final normalization
        output = self.postprocess_layer(output)

        return output, decoder_state
    def encode(self, input, decoder_state, input_pos=None, input_lang=None):

        buffers = decoder_state.attention_buffers
        src_lang = input_lang
        input = input.transpose(0, 1)
        # Embedding stage (and scale the embedding)
        src_emb = embedded_dropout(self.src_embedding, input, dropout=self.word_dropout if self.training else 0) \
                  * math.sqrt(self.model_size)

        if self.use_language_embedding:
            if self.language_embedding_type in ["sum", "all_sum"]:
                src_lang_emb = self.language_embeddings(src_lang)
                src_emb += src_lang_emb

        emb = src_emb
        src_len = input.size(0)
        bsz = input.size(1)
        mask_src_src = input.eq(onmt.constants.PAD).byte()  # src_len x bsz (input is time-major here)
        mask_src = mask_src_src.unsqueeze(0)

        attn_mask = mask_src.bool()  # 1 x src_len x batch_size

        output = emb

        # Applying dropout (the input is already T x B x H)
        output = self.preprocess_layer(output)

        klen = src_len
        pos = torch.arange(klen - 1,
                           -klen,
                           -1.0,
                           device=emb.device,
                           dtype=emb.dtype)

        pos_emb = self.positional_encoder(pos)

        # FORWARD PASS
        coverage = None
        for i, layer in enumerate(self.layer_modules):
            # context and context_mask are None
            buffer = buffers[i] if i in buffers else None
            # output, coverage, buffer = layer.step(output, None, attn_mask, None, buffer)
            output, coverage, buffer = layer(output,
                                             None,
                                             pos_emb,
                                             attn_mask,
                                             None,
                                             incremental=True,
                                             incremental_cache=buffer)
            decoder_state.update_attention_buffer(buffer, i)

        # Final normalization
        output = self.postprocess_layer(output)

        return output, decoder_state
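
The encode functions above build relative positions with torch.arange(klen - 1, -klen, -1.0), i.e. 2*klen - 1 values running from klen - 1 down to -(klen - 1), and hand them to self.positional_encoder. A minimal Transformer-XL-style sinusoidal encoder over such positions might look like the sketch below (an assumption about what positional_encoder does, not the repository's actual module):

import torch

def sinusoidal_positions(pos, model_size):
    # pos: 1-D float tensor of relative positions, possibly negative,
    # e.g. torch.arange(klen - 1, -klen, -1.0) -> 2*klen - 1 entries.
    inv_freq = 1.0 / (10000 ** (torch.arange(0.0, model_size, 2.0) / model_size))
    sinusoid = pos[:, None] * inv_freq[None, :]          # (2*klen - 1) x (model_size / 2)
    return torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1)[:, None, :]  # ... x 1 x model_size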
Example #6
    def process_embedding(self, input, input_lang=None):

        input_ = input

        emb = embedded_dropout(
            self.word_lut,
            input_,
            dropout=self.word_dropout if self.training else 0)
        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)

        if self.use_language_embedding:
            lang_emb = self.language_embeddings(input_lang)  # B x H or 1 x H
            if self.language_embedding_type == 'sum':
                emb = emb + lang_emb.unsqueeze(1)
            elif self.language_embedding_type == 'concat':
                lang_emb = lang_emb.unsqueeze(1).expand_as(emb)
                concat_emb = torch.cat([emb, lang_emb], dim=-1)
                emb = torch.relu(self.projector(concat_emb))
            else:
                raise NotImplementedError
        return emb
Example #7
    def forward(self, input, input_pos=None, input_lang=None, streaming=False, **kwargs):
        """
        Inputs Shapes:
            input: batch_size x src_len (want to transpose)
        Outputs Shapes:
            out: batch_size x src_len x d_model
            mask_src
        """

        """ Embedding: batch_size x src_len x d_model """
        if self.input_type == "text":
            bsz_first_input = input
            input = input.transpose(0, 1)
            # mask_src = input.eq(onmt.constants.PAD).unsqueeze(1)  # batch_size x src_len x 1 for broadcasting

            dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

            if streaming:
                streaming_state = kwargs.get('streaming_state', None)
                mems = streaming_state.src_mems
                # mem_len = streaming_state.src_mems[0].size(0)
                # mem_len = streaming_state.prev_src_mem_size
                mem_len = mems[0].size(0) if mems is not None else 0
                input_length = kwargs.get('src_lengths', None)
                streaming_state = kwargs.get('streaming_state', None)
                mask_src = self.create_stream_mask(input, input_length, mem_len)
                mask_src = mask_src.unsqueeze(2)
            else:
                mem_len = 0
                mask_src = input.eq(onmt.constants.PAD).unsqueeze(0)  # 1 x src_len x batch_size for broadcasting
                mems = None

            emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)

            """ Adding language embeddings """
            if self.use_language_embedding:
                assert self.language_embedding is not None
                # There is no "unsqueeze" here because the input is T x B x H and lang_emb is B x H
                if self.language_embedding_type in ['sum', 'all_sum']:
                    lang_emb = self.language_embedding(input_lang)
                    # print(lang_emb.size(), emb.size())
                    emb = emb + lang_emb.unsqueeze(0)

        else:
            if streaming:
                raise NotImplementedError

            if not self.cnn_downsampling:
                mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq(onmt.constants.PAD).unsqueeze(0)
                dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
                input = input.narrow(2, 1, input.size(2) - 1)
                emb = self.audio_trans(input.contiguous().view(-1, input.size(2))).view(input.size(0),
                                                                                        input.size(1), -1)
                emb = emb.type_as(input)
            else:
                long_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
                input = input.narrow(2, 1, input.size(2) - 1)

                # first resizing to fit the CNN format
                input = input.view(input.size(0), input.size(1), -1, self.channels)
                input = input.permute(0, 3, 1, 2)

                input = self.audio_trans(input)
                input = input.permute(0, 2, 1, 3).contiguous()
                input = input.view(input.size(0), input.size(1), -1)
                # print(input.size())
                input = self.linear_trans(input)

                mask_src = long_mask[:, 0:input.size(1) * 4:4].transpose(0, 1).unsqueeze(0)
                dec_attn_mask = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1)
                # the size seems to be B x T ?
                emb = input

            emb = emb.transpose(0, 1)
            input = input.transpose(0, 1)
            abs_pos = None
            mem_len = 0
            mems = None

        if self.unidirectional:
            qlen = input.size(0)
            klen = qlen + mem_len
            attn_mask_src = torch.triu(
                emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None]

            pad_mask = mask_src

            mask_src = pad_mask + attn_mask_src
            # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0)
            mask_src = mask_src.gt(0)

        if onmt.constants.torch_version >= 1.2:
            mask_src = mask_src.bool()

        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)

        """ Adding positional encoding """
        qlen = input.size(0)
        klen = qlen + mem_len

        # Asynchronous positions: 2K+1 positions instead of K+1
        if self.unidirectional:
            pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)
        else:
            pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device, dtype=emb.dtype)

        # pos_emb has size 2T+1 x 1 x H
        pos_emb = self.positional_encoder(pos, bsz=input.size(1) if self.fast_self_attn else None)

        if self.learnable_position_encoding:
            raise NotImplementedError

        # B x T x H -> T x B x H
        context = emb

        if streaming:
            hids = [context]

        # Apply dropout to both context and pos_emb
        context = self.preprocess_layer(context)

        pos_emb = self.preprocess_layer(pos_emb)

        if self.reversible:
            context = torch.cat([context, context], dim=-1)

            assert streaming is not True, "Streaming and Reversible is not usable yet."
            # print(context.size(), pos_emb.size())
            context = ReversibleEncoderFunction.apply(context, pos_emb, self.layer_modules, mask_src)
        else:
            for i, layer in enumerate(self.layer_modules):
                # src_len x batch_size x d_model

                mems_i = mems[i] if mems is not None and streaming and self.max_memory_size > 0 else None
                context = layer(context, pos_emb, mask_src, mems=mems_i)

                if streaming:
                    hids.append(context)

        # final layer norm
        context = self.postprocess_layer(context)

        output_dict = defaultdict(lambda: None, {'context': context, 'src_mask': dec_attn_mask, 'src': input})

        if streaming:
            # streaming_state.prev_src_mem_size += sum(input_length.tolist())
            # streaming_state.prune_source_memory(self.max_memory_size)
            streaming_state.update_src_mems(hids, qlen)
            output_dict['streaming_state'] = streaming_state

        return output_dict
Example #8
    def forward(self,
                input,
                context,
                src,
                input_pos=None,
                input_lang=None,
                streaming=False,
                **kwargs):
        """
                Inputs Shapes:
                    input: (Variable) batch_size x len_tgt (want to transpose)
                    context: (Variable) batch_size x src_len x d_model
                    mask_src (Tensor) batch_size x src_len
                Outputs Shapes:
                    out: batch_size x len_tgt x d_model
                    coverage: batch_size x len_tgt x src_len

                """
        """ Embedding: batch_size x len_tgt x d_model """
        input = input.transpose(0, 1)  # T x B
        emb = embedded_dropout(
            self.word_lut,
            input,
            dropout=self.word_dropout if self.training else 0)
        emb = emb * math.sqrt(self.model_size)

        if streaming:
            src_lengths = kwargs.get("src_lengths", None)
            tgt_lengths = kwargs.get("tgt_lengths", None)
            streaming_state = kwargs.get("streaming_state")
            # mems = streaming_state.tgt_mems
            mem_len = streaming_state.prev_tgt_mem_size
            extra_context = streaming_state.extra_context
            extra_context_length = extra_context.size(
                0) if extra_context is not None else 0
            # mem_len = mems[0].size(0) if mems is not None else 0
        else:
            mem_len = 0
            mems = None
            extra_context = None

        if self.double_position:
            assert input_pos is not None
            tgt_len, bsz = input_pos.size(0), input_pos.size(1)
            input_pos_ = input_pos.view(-1).type_as(emb)
            abs_pos = self.positional_encoder(input_pos_).squeeze(1).view(
                tgt_len, bsz, -1)

            emb = emb + abs_pos

        if self.use_language_embedding:
            lang_emb = self.language_embeddings(input_lang)  # B x H or 1 x H
            if self.language_embedding_type == 'sum':
                emb = emb + lang_emb
            elif self.language_embedding_type == 'concat':
                # replace the bos embedding with the language
                bos_emb = lang_emb.expand_as(emb[0])
                emb[0] = bos_emb

                lang_emb = lang_emb.unsqueeze(0).expand_as(emb)
                concat_emb = torch.cat([emb, lang_emb], dim=-1)
                emb = torch.relu(self.projector(concat_emb))
            else:
                raise NotImplementedError

        if context is not None:
            if self.encoder_type == "audio":
                if not self.encoder_cnn_downsampling:
                    mask_src = src.narrow(2, 0, 1).squeeze(2).eq(
                        onmt.constants.PAD).unsqueeze(1)
                else:
                    long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(
                        onmt.constants.PAD)
                    mask_src = long_mask[:,
                                         0:context.size(0) * 4:4].unsqueeze(1)
            else:
                if streaming:
                    context_attn_mask = self.create_context_mask(
                        input, src, src_lengths, tgt_lengths,
                        extra_context_length)
                    mask_src = context_attn_mask.unsqueeze(0)
                else:
                    mask_src = src.eq(onmt.constants.PAD).unsqueeze(1)
        else:
            mask_src = None

        qlen = input.size(0)
        klen = qlen + mem_len
        # preparing self-attention mask. The input is either left or right aligned

        if streaming:
            dec_attn_mask = self.create_self_attn_mask(input, tgt_lengths,
                                                       mem_len)
        else:
            dec_attn_mask = torch.triu(emb.new_ones(qlen, klen),
                                       diagonal=1 + mem_len).byte()[:, :, None]
            pad_mask = input.eq(onmt.constants.PAD).byte()  # L x B

            dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0)
            dec_attn_mask = dec_attn_mask.gt(0)
            if onmt.constants.torch_version >= 1.2:
                dec_attn_mask = dec_attn_mask.bool()

        pos = torch.arange(klen - 1,
                           -1,
                           -1.0,
                           device=emb.device,
                           dtype=emb.dtype)

        output = self.preprocess_layer(emb.contiguous())

        if streaming:
            hids = [output]
            if extra_context is not None:
                context = torch.cat([extra_context, context], dim=0)
                # print(context.size(), context_attn_mask.size())

        for i, layer in enumerate(self.layer_modules):
            # batch_size x src_len x d_model output, coverage = layer(output, context, pos_emb, self.r_w_bias,
            # self.r_r_bias, dec_attn_mask, mask_src)
            # mems_i = mems[i] if mems is not None and streaming and
            # self.stream_context in ['local', 'global'] else None
            if streaming:
                buffer = streaming_state.tgt_buffer[i]
                output, coverage, buffer = layer(output,
                                                 context,
                                                 dec_attn_mask,
                                                 context_attn_mask,
                                                 incremental=True,
                                                 incremental_cache=buffer,
                                                 reuse_source=False)
                streaming_state.tgt_buffer[i] = buffer
            else:
                output, coverage, _ = layer(output, context, dec_attn_mask,
                                            mask_src)
                # if streaming:
                #     hids.append(output)

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        output = self.postprocess_layer(output)

        output_dict = {
            'hidden': output,
            'coverage': coverage,
            'context': context
        }
        output_dict = defaultdict(lambda: None, output_dict)

        if streaming:
            streaming_state.prev_tgt_mem_size += sum(tgt_lengths.tolist())
            streaming_state.prune_target_memory(self.max_memory_size)

            # if we use the extra context: keep the last context
            if self.extra_context_size > 0:
                extra_context = context[-self.extra_context_size:].detach()
                streaming_state.extra_context = extra_context

            # if self.stream_context in ['local', 'global']:
            #     streaming_state.update_tgt_mems(hids, qlen)
            output_dict['streaming_state'] = streaming_state

        return output_dict
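
In the non-streaming branch above, the decoder self-attention mask is a triu causal mask combined with a padding mask over the time-major target tokens. Assuming mem_len == 0, the same construction can be isolated in a small helper (a sketch with a hypothetical name, using boolean tensors directly):

import torch

def build_causal_pad_mask(tgt_tokens, pad_idx):
    # tgt_tokens: qlen x bsz tensor of token ids (time-major, as in the decoder above).
    # Returns a qlen x qlen x bsz boolean mask; True marks key positions a query
    # must NOT attend to (future tokens or padding).
    qlen, _ = tgt_tokens.size()
    causal = torch.triu(torch.ones(qlen, qlen, dtype=torch.bool,
                                   device=tgt_tokens.device), diagonal=1)
    pad = tgt_tokens.eq(pad_idx)                     # qlen x bsz, True at PAD keys
    return causal.unsqueeze(-1) | pad.unsqueeze(0)   # broadcasts to qlen x qlen x bsz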
Example #9
    def forward(self,
                input,
                input_pos=None,
                input_lang=None,
                streaming=False,
                **kwargs):
        """
        Inputs Shapes:
            input: batch_size x src_len (want to transpose)
        Outputs Shapes:
            out: batch_size x src_len x d_model
            mask_src
        """
        """ Embedding: batch_size x src_len x d_model """
        if self.input_type == "text":
            bsz_first_input = input
            input = input.transpose(0, 1)
            # mask_src = input.eq(onmt.constants.PAD).unsqueeze(0)  # batch_size x src_len x 1 for broadcasting

            dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

            if streaming:
                raise NotImplementedError
                streaming_state = kwargs.get('streaming_state', None)
                mems = streaming_state.src_mems
                # mem_len = streaming_state.src_mems[0].size(0)
                mem_len = streaming_state.prev_src_mem_size
                input_length = kwargs.get('src_lengths', None)
                streaming_state = kwargs.get('streaming_state', None)
                mask_src = self.create_stream_mask(input, input_length,
                                                   mem_len)
                mask_src = mask_src.unsqueeze(2)
            else:
                mem_len = 0
                mask_src = input.eq(onmt.constants.PAD).unsqueeze(
                    0)  # 1 x src_len x batch_size for broadcasting
                mems = None

            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)

            if self.double_position:
                assert input_pos is not None
                # flatten
                src_len, bsz = input_pos.size(0), input_pos.size(1)
                input_pos_ = input_pos.contiguous().view(-1).type_as(emb)
                abs_pos = self.positional_encoder(input_pos_)
                abs_pos = abs_pos.squeeze(1).view(src_len, bsz, -1)

            else:
                abs_pos = None
            """ Adding language embeddings """
            if self.use_language_embedding:
                assert self.language_embedding is not None
                # There is no "unsqueeze" here because the input is T x B x H and lang_emb is B x H
                if self.language_embedding_type in ['sum', 'all_sum']:
                    lang_emb = self.language_embedding(input_lang)
                    emb = emb + lang_emb.unsqueeze(1)

        else:
            if streaming:
                raise NotImplementedError

            if not self.cnn_downsampling:
                mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq(
                    onmt.constants.PAD).unsqueeze(0)
                dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq(
                    onmt.constants.PAD).unsqueeze(1)
                input = input.narrow(2, 1, input.size(2) - 1)
                emb = self.audio_trans(input.contiguous().view(
                    -1, input.size(2))).view(input.size(0), input.size(1), -1)
            else:
                long_mask = input.narrow(2, 0,
                                         1).squeeze(2).eq(onmt.constants.PAD)
                input = input.narrow(2, 1, input.size(2) - 1)

                # first resizing to fit the CNN format
                input = input.view(input.size(0), input.size(1), -1,
                                   self.channels)
                input = input.permute(0, 3, 1, 2)

                input = self.audio_trans(input)
                input = input.permute(0, 2, 1, 3).contiguous()
                input = input.view(input.size(0), input.size(1), -1)
                # print(input.size())
                input = self.linear_trans(input)

                mask_src = long_mask[:, 0:input.size(1) *
                                     4:4].transpose(0, 1).unsqueeze(0)
                dec_attn_mask = long_mask[:,
                                          0:input.size(1) * 4:4].unsqueeze(1)
                # the size seems to be B x T ?
                emb = input

            emb = emb.transpose(0, 1)
            input = input.transpose(0, 1)
            abs_pos = None
            mem_len = 0

        if onmt.constants.torch_version >= 1.2:
            mask_src = mask_src.bool()
        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)

        if self.double_position and abs_pos is not None:
            # adding position encoding
            emb = emb + abs_pos
        """ Adding positional encoding """
        qlen = input.size(0)
        klen = qlen + mem_len

        # Asynchronous positions: 2K+1 positions instead of K+1

        # because the batch dimension is lacking

        # B x T x H -> T x B x H
        context = emb

        # Apply dropout to both context and pos_emb
        context = self.preprocess_layer(context)

        for i, layer in enumerate(self.layer_modules):
            # src_len x batch_size x d_model

            if streaming:
                buffer = streaming_state.src_buffer[i]
                context, buffer = layer(context,
                                        mask_src,
                                        incremental=True,
                                        incremental_cache=buffer)
                streaming_state.src_buffer[i] = buffer
            else:
                context = layer(context, mask_src)

        # last layer norm
        context = self.postprocess_layer(context)

        output_dict = defaultdict(lambda: None, {
            'context': context,
            'src_mask': dec_attn_mask,
            'src': input
        })

        if streaming:
            streaming_state.prev_src_mem_size += sum(input_length.tolist())
            streaming_state.prune_source_memory(self.max_memory_size)
            # streaming_state.update_src_mems(hids, qlen)
            output_dict['streaming_state'] = streaming_state

        return output_dict
Example #10
    def forward(self, input, context, src):
        """
        Inputs Shapes: 
            input: (Variable) batch_size x len_tgt (want to transpose)
            context: (Variable) batch_size x len_src x d_model
            mask_src (Tensor) batch_size x len_src
        Outputs Shapes:
            out: batch_size x len_tgt x d_model
            coverage: batch_size x len_tgt x len_src
            
        """
        """ Embedding: batch_size x len_tgt x d_model """

        emb = embedded_dropout(
            self.word_lut,
            input,
            dropout=self.word_dropout if self.training else 0)
        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)
        if isinstance(emb, tuple):
            emb = emb[0]
        emb = self.preprocess_layer(emb)

        mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1)

        pad_mask_src = torch.autograd.Variable(src.data.ne(onmt.constants.PAD))

        len_tgt = input.size(1)
        mask_tgt = input.data.eq(
            onmt.constants.PAD).unsqueeze(1) + self.mask[:len_tgt, :len_tgt]
        mask_tgt = torch.gt(mask_tgt, 0)

        output = emb.contiguous()

        pad_mask_tgt = torch.autograd.Variable(
            input.data.ne(onmt.constants.PAD))  # batch_size x len_src
        pad_mask_src = torch.autograd.Variable(1 - mask_src.squeeze(1))

        memory_bank = None

        for i, layer in enumerate(self.layer_modules):

            if len(self.layer_modules
                   ) - i <= onmt.constants.checkpointing and self.training:

                output, memory_bank, coverage = checkpoint(
                    custom_layer(layer), output, context, memory_bank,
                    mask_tgt, mask_src, pad_mask_tgt,
                    pad_mask_src)  # batch_size x len_src x d_model

            else:
                output, memory_bank, coverage = layer(
                    output, context, memory_bank, mask_tgt, mask_src,
                    pad_mask_tgt,
                    pad_mask_src)  # batch_size x len_src x d_model

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        output = self.postprocess_layer(output)

        return output, coverage
    def step(self, input, decoder_state):

        src = decoder_state.src.transpose(
            0, 1) if decoder_state.src is not None else None
        tgt = input
        tgt_lang = decoder_state.tgt_lang
        src_lang = decoder_state.src_lang
        # print(src.size(), tgt.size())
        # print(src_lang, tgt_lang)

        tgt_len = tgt.size(1)
        src_len = src.size(1)
        bsz = tgt.size(0)

        # Embedding stage (and scale the embedding)
        src_emb = embedded_dropout(self.src_embedding, src, dropout=self.word_dropout if self.training else 0) \
                  * math.sqrt(self.model_size)
        tgt_emb = embedded_dropout(self.tgt_embedding, tgt, dropout=self.word_dropout if self.training else 0) \
                  * math.sqrt(self.model_size)

        # Add position encoding
        src_emb = self.time_transformer(src_emb)
        tgt_emb = self.time_transformer(tgt_emb)

        if self.use_language_embedding:
            if self.language_embedding_type in ["sum", "all_sum"]:
                src_lang_emb = self.language_embeddings(src_lang)
                src_emb += src_lang_emb.unsqueeze(1)
                tgt_lang_emb = self.language_embeddings(tgt_lang)
                tgt_emb += tgt_lang_emb.unsqueeze(1)

        # concatenate embedding
        emb = torch.cat([src_emb, tgt_emb], dim=1)  # batch_size x L x H (transposed to L x B x H below)

        # prepare self-attention mask
        # For the source: we have two different parts
        # [1 x src_len x batch_size]
        # mask_src_src = src.eq(onmt.constants.PAD).unsqueeze(0).byte()
        # src_pad_mask = mask_src_src
        # # Attention from src to target: everything is padded
        # mask_src_tgt = mask_src_src.new_ones(1, 1, 1).expand(src_len, tgt_len, bsz)
        # # [src_len x L x batch_size]
        # mask_src = torch.cat([mask_src_src.expand(src_len, src_len, bsz), mask_src_tgt], dim=1)
        # mask_src = mask_src.bool()
        # mask_src_src = src.eq(onmt.constants.PAD).unsqueeze(1).byte()  # B x 1 x src_len
        # mask_src_tgt = mask_src_src.new_ones(bsz, src_len, tgt_len)  # bsz x src_len x tgt_len
        #
        # mask_src = torch.cat([mask_src_src.expand(bsz, src_len, src_len), mask_src_tgt], dim=-1)
        #
        # # For the target:
        # mask_tgt_tgt = tgt.eq(onmt.constants.PAD).byte().unsqueeze(1) + self.mask[:tgt_len, :tgt_len]
        # mask_tgt_tgt = torch.gt(mask_tgt_tgt, 0).byte()  # bsz x tgt_len x tgt_len
        #
        # mask_tgt_src = mask_tgt_tgt.new_zeros(bsz, tgt_len, src_len) + src.eq(onmt.constants.PAD).unsqueeze(1).byte()
        # mask_tgt = torch.cat([mask_tgt_src, mask_tgt_tgt], dim=-1)  # bsz x tgt_len x T
        # attn_mask = torch.cat([mask_src, mask_tgt], dim=1).bool()  # L x L x batch_size

        attn_mask = self.gen_mask(src, input)
        # seq = torch.cat([src, input], dim=-1)
        # seq_len = seq.size(1)
        # attn_mask = self.mask[:seq_len, :seq_len] + seq.eq(onmt.constants.PAD).byte().unsqueeze(1)
        # attn_mask = torch.gt(attn_mask, 0).bool()

        output = emb

        # Applying dropout and transpose to T x B x H
        output = self.preprocess_layer(output).transpose(0, 1)

        # FORWARD PASS
        coverage = None
        for i, layer in enumerate(self.layer_modules):
            output, coverage = layer(output, None, attn_mask,
                                     None)  # context and context_mask are None

        # Final normalization
        output = self.postprocess_layer(output)

        output = output[-1:, :, :]

        output_dict = defaultdict(lambda: None)
        output_dict['hidden'] = output

        logprobs = self.generator[0](output_dict).squeeze(0)

        output_dict['src'] = decoder_state.src.transpose(0, 1)
        output_dict['log_prob'] = logprobs
        output_dict['coverage'] = logprobs.new(bsz, tgt_len, src_len).zero_()
        # buffers = decoder_state.attention_buffers
        # tgt_lang = decoder_state.tgt_lang
        # src = decoder_state.src.transpose(0, 1) if decoder_state.src is not None else None
        #
        # if decoder_state.concat_input_seq:
        #     if decoder_state.input_seq is None:
        #         decoder_state.input_seq = input
        #     else:
        #         # concatenate the last input to the previous input sequence
        #         decoder_state.input_seq = torch.cat([decoder_state.input_seq, input], 0)
        #
        #     # For Transformer, both inputs are assumed as B x T (batch first)
        #     input = decoder_state.input_seq.transpose(0, 1)
        #     src = decoder_state.src.transpose(0, 1) if decoder_state.src is not None else None
        #
        # if input.size(1) > 1:
        #     input_ = input[:, -1].unsqueeze(1)
        # else:
        #     input_ = input
        # """ Embedding: batch_size x 1 x d_model """
        # # check = input_.gt(self.word_lut.num_embeddings)
        # print(input.size())
        # emb = self.tgt_embedding(input_) * math.sqrt(self.model_size)
        #
        # """ Adding positional encoding """
        # emb = self.time_transformer(emb, t=input.size(1))
        #
        # if self.use_language_embedding:
        #     if self.language_embedding_type in ["sum", "all_sum"]:
        #
        #         tgt_lang_emb = self.language_embeddings(tgt_lang)
        #         emb += tgt_lang_emb.unsqueeze(1)
        #
        # emb = emb.transpose(0, 1)
        #
        # # attention mask For the target:
        # tgt_len = input.size(1)
        # bsz = input.size(0)
        # src_len = src.size(1)
        # mask_tgt_tgt = input.eq(onmt.constants.PAD).byte().unsqueeze(1) + self.mask[:tgt_len, :tgt_len]
        # mask_tgt_tgt = torch.gt(mask_tgt_tgt, 0).byte()  # bsz x tgt_len x tgt_len
        #
        # mask_tgt_src = mask_tgt_tgt.new_zeros(bsz, tgt_len, src_len) + src.eq(onmt.constants.PAD).unsqueeze(1).byte()
        #
        # mask_tgt = torch.cat([mask_tgt_src, mask_tgt_tgt], dim=-1)  # bsz x tgt_len x T
        #
        # # take the last element of the 'target sequence' for the mask
        # attn_mask = mask_tgt[:, -1, :].unsqueeze(1).bool()
        #
        # output = emb
        #
        # for i, layer in enumerate(self.layer_modules):
        #     buffer = buffers[i] if i in buffers else None
        #     assert (output.size(0) == 1)
        #
        #     output, coverage, buffer = layer.step(output, None, attn_mask, None, buffer=buffer)
        #
        #     decoder_state.update_attention_buffer(buffer, i)
        #
        # # Final normalization

        # output_dict = defaultdict(lambda: None)
        # output_dict['hidden'] = output
        #
        # logprobs = self.generator[0](output_dict).squeeze(0)
        #
        # output_dict['src'] = decoder_state.src.transpose(0, 1)
        # output_dict['log_prob'] = logprobs
        # output_dict['coverage'] = logprobs.new(bsz, tgt_len, src_len).zero_()

        return output_dict
Example #12
    def forward(self,
                input,
                context,
                src,
                input_pos=None,
                src_lang=None,
                tgt_lang=None,
                streaming=False,
                **kwargs):
        """
                Inputs Shapes:
                    input: (Variable) batch_size x len_tgt (want to transpose)
                    context: (Variable) batch_size x src_len x d_model
                    mask_src (Tensor) batch_size x src_len
                Outputs Shapes:
                    out: batch_size x len_tgt x d_model
                    coverage: batch_size x len_tgt x src_len

                """
        """ Embedding: batch_size x len_tgt x d_model """
        input = input.transpose(0, 1)  # T x B
        emb = embedded_dropout(
            self.word_lut,
            input,
            dropout=self.word_dropout if self.training else 0)
        emb = emb * math.sqrt(self.model_size)

        mem_len = 0
        mems = None
        extra_context = None

        if self.use_language_embedding:
            lang_emb = self.language_embeddings(tgt_lang)  # B x H or 1 x H
            if self.language_embedding_type == 'sum':
                emb = emb + lang_emb
            elif self.language_embedding_type == 'concat':
                lang_emb = lang_emb.unsqueeze(0).expand_as(emb)
                concat_emb = torch.cat([emb, lang_emb], dim=-1)
                emb = torch.relu(self.projector(concat_emb))
            else:
                raise NotImplementedError

        if context is not None:
            if self.encoder_type == "audio":
                if not self.encoder_cnn_downsampling:
                    mask_src = src.narrow(2, 0, 1).squeeze(2).eq(
                        onmt.constants.PAD).unsqueeze(1)
                else:
                    long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(
                        onmt.constants.PAD)
                    mask_src = long_mask[:,
                                         0:context.size(0) * 4:4].unsqueeze(1)
            else:
                mask_src = src.eq(onmt.constants.PAD).unsqueeze(1)
        else:
            mask_src = None

        qlen = input.size(0)
        klen = qlen + mem_len
        # preparing self-attention mask. The input must be left-aligned

        dec_attn_mask = torch.triu(emb.new_ones(qlen, klen),
                                   diagonal=1 + mem_len).byte()[:, :, None]

        dec_attn_mask = dec_attn_mask.bool()

        # pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)
        if not self.learnable_position_encoding:
            pos = torch.arange(klen - 1,
                               -1,
                               -1.0,
                               device=emb.device,
                               dtype=emb.dtype)
            pos_emb = self.positional_encoder(pos, bsz=input.size(1))
            pos_emb = self.preprocess_layer(pos_emb)
        else:
            range_vec = torch.arange(klen, device=emb.device)
            range_mat = range_vec.unsqueeze(-1).expand(-1,
                                                       klen).transpose(0, 1)
            distance_mat = range_vec - range_mat.transpose(0, 1)
            distance_mat.clamp_(-self.max_pos_length,
                                self.max_pos_length).add_(self.max_pos_length)
            pos_emb = distance_mat

        # pos_emb = self.positional_encoder(pos, bsz=input.size(1))
        output = self.preprocess_layer(emb.contiguous())
        # pos_emb = self.preprocess_layer(pos_emb)

        lfv_vector, lid_logits = None, list()

        if self.mpw:
            src_lang = self.factor_embeddings(src_lang).squeeze(0)
            tgt_lang = self.factor_embeddings(tgt_lang).squeeze(0)
            assert src_lang.ndim == 1 and tgt_lang.ndim == 1

        for i, layer in enumerate(self.layer_modules):
            output, coverage, _ = layer(output,
                                        context,
                                        pos_emb,
                                        lfv_vector,
                                        dec_attn_mask,
                                        mask_src,
                                        src_lang=src_lang,
                                        tgt_lang=tgt_lang)

        output = self.postprocess_layer(output, factor=tgt_lang)

        output_dict = {
            'hidden': output,
            'coverage': coverage,
            'context': context,
            'lid_logits': lid_logits
        }
        output_dict = defaultdict(lambda: None, output_dict)

        return output_dict
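
When learnable_position_encoding is set, Example #12 (and Example #13 below) replaces sinusoidal positions with a clamped matrix of pairwise distances that later indexes a learnable relative-position table of size 2 * max_pos_length + 1. The same computation written more directly (an equivalent sketch, not the repository's code):

import torch

def relative_distance_indices(klen, max_pos_length, device=None):
    # distance[i, j] = j - i, clamped to [-max_pos_length, max_pos_length]
    # and shifted into [0, 2 * max_pos_length] so it can index an embedding table.
    range_vec = torch.arange(klen, device=device)
    distance = range_vec[None, :] - range_vec[:, None]   # klen x klen
    return distance.clamp(-max_pos_length, max_pos_length) + max_pos_length

For instance, relative_distance_indices(5, 2) yields a 5 x 5 LongTensor with values in [0, 4].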
Example #13
    def forward(self,
                input,
                input_pos=None,
                input_lang=None,
                streaming=False,
                **kwargs):
        """
        Inputs Shapes:
            input: batch_size x src_len (want to transpose)
        Outputs Shapes:
            out: batch_size x src_len x d_model
            mask_src
        """
        """ Embedding: batch_size x src_len x d_model """
        bsz_first_input = input
        input = input.transpose(0, 1)

        dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

        mem_len = 0

        mems = None

        emb = embedded_dropout(
            self.word_lut,
            input,
            dropout=self.word_dropout if self.training else 0)
        if self.early_emb_scale:
            """ Scale the emb by sqrt(d_model) """
            emb = emb * math.sqrt(self.model_size)
        """ Adding language embeddings """
        if self.use_language_embedding:
            assert self.language_embedding is not None
            # There is no "unsqueeze" here because the input is T x B x H and lang_emb is B x H
            if self.language_embedding_type in ['sum', 'all_sum']:
                lang_emb = self.language_embedding(input_lang)
                emb = emb + lang_emb.unsqueeze(0)
        """ Adding positional encoding """
        qlen = input.size(0)
        klen = qlen + mem_len

        # Asynchronous positions: 2K+1 positions instead of K+1
        if not self.absolute_position_encoding:
            if not self.learnable_position_encoding:
                pos = torch.arange(klen - 1,
                                   -klen,
                                   -1.0,
                                   device=emb.device,
                                   dtype=emb.dtype)
                # pos_emb has size 2T+1 x 1 x H
                pos_emb = self.positional_encoder(pos, bsz=input.size(1))
                pos_emb = self.preprocess_layer(pos_emb)
            else:
                range_vec = torch.arange(klen, device=emb.device)
                range_mat = range_vec.unsqueeze(-1).expand(-1, klen).transpose(
                    0, 1)
                distance_mat = range_vec - range_mat.transpose(0, 1)
                distance_mat.clamp_(-self.max_pos_length,
                                    self.max_pos_length).add_(
                                        self.max_pos_length)
                pos_emb = distance_mat
                # pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device).long()
                # pos.clamp_(-self.max_pos_length, self.max_pos_length).add_(self.max_pos_length)
                # pos_emb = pos.unsqueeze(1)

            mask_src = input.eq(onmt.constants.PAD).unsqueeze(
                0)  # 1 x src_len x batch_size for broadcasting
        else:
            # Absolute position encoding from 0 -> n
            pos, pos_emb = None, None
            emb = self.positional_encoder(emb.transpose(0, 1)).transpose(0, 1)
            mask_src = bsz_first_input.eq(
                onmt.constants.PAD)  # batch_size x src_len

        if onmt.constants.torch_version >= 1.2:
            mask_src = mask_src.bool()

        if not self.early_emb_scale:
            """ Scale the emb by sqrt(d_model) """
            emb = emb * math.sqrt(self.model_size)

        # context size is now T x B x H
        context = self.preprocess_layer(emb)

        if self.reversible:
            context = reversible_encoder(self.layer_modules, context, pos_emb,
                                         mask_src)
        else:
            for i, layer in enumerate(self.layer_modules):
                # src_len x batch_size x d_model
                context = layer(context,
                                pos_emb,
                                mask_src,
                                src_lang=input_lang)
                # if self.checkpointing == 0 or self.training is False:
                #     context = layer(context, pos_emb, mask_src, src_lang=input_lang)
                # else:
                #     context = checkpoint(create_forward_function(layer), context, pos_emb, mask_src, input_lang)

        # final layer norm. we can consider this layer norm as a part of the output layer/function
        context = self.postprocess_layer(context)

        output_dict = defaultdict(lambda: None, {
            'context': context,
            'src_mask': dec_attn_mask,
            'src': input
        })

        return output_dict
Example #14
    def forward(self,
                input,
                context,
                src,
                input_pos=None,
                input_lang=None,
                streaming=False,
                **kwargs):
        """
                Inputs Shapes:
                    input: (Variable) batch_size x len_tgt (want to transpose)
                    context: (Variable) batch_size x src_len x d_model
                    mask_src (Tensor) batch_size x src_len
                Outputs Shapes:
                    out: batch_size x len_tgt x d_model
                    coverage: batch_size x len_tgt x src_len

                """
        """ Embedding: batch_size x len_tgt x d_model """
        emb = embedded_dropout(
            self.word_lut,
            input,
            dropout=self.word_dropout if self.training else 0)
        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)

        if self.use_language_embedding:
            lang_emb = self.language_embeddings(input_lang)  # B x H or 1 x H
            if self.language_embedding_type == 'sum':
                emb = emb + lang_emb
            elif self.language_embedding_type == 'concat':
                # replace the bos embedding with the language
                bos_emb = lang_emb.expand_as(emb[:, 0, :])
                emb[:, 0, :] = bos_emb

                lang_emb = lang_emb.unsqueeze(1).expand_as(emb)
                concat_emb = torch.cat([emb, lang_emb], dim=-1)
                emb = torch.relu(self.projector(concat_emb))
            else:
                raise NotImplementedError

        if context is not None:
            if self.encoder_type == "audio":
                if not self.encoder_cnn_downsampling:
                    mask_src = src.data.narrow(2, 0, 1).squeeze(2).eq(
                        onmt.constants.PAD).unsqueeze(1)
                else:
                    long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(
                        onmt.constants.PAD)
                    mask_src = long_mask[:,
                                         0:context.size(0) * 4:4].unsqueeze(1)
            else:

                mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1)
        else:
            mask_src = None

        len_tgt = input.size(1)
        mask_tgt = torch.triu(emb.new_ones(len_tgt, len_tgt),
                              diagonal=1).byte().unsqueeze(0)

        mask_tgt = mask_tgt.bool()

        time_embedding = self.positional_encoder.get_positional_embeddings(emb)

        output = self.preprocess_layer(emb.transpose(0, 1).contiguous())

        for i in range(self.max_layers):
            layer_tensor = torch.LongTensor([i]).to(output.device)
            layer_embedding = self.layer_embeddings(layer_tensor)

            output, coverage, _ = self.universal_layer(output, time_embedding,
                                                       layer_embedding,
                                                       context, mask_tgt,
                                                       mask_src)

        # last layer norm
        output = self.postprocess_layer(output)

        output_dict = {
            'hidden': output,
            'coverage': coverage,
            'context': context
        }
        output_dict = defaultdict(lambda: None, output_dict)

        return output_dict
    def forward(self,
                input,
                input_pos=None,
                input_lang=None,
                streaming=False,
                **kwargs):
        """
        Inputs Shapes:
            input: batch_size x src_len (want to transpose)
        Outputs Shapes:
            out: batch_size x src_len x d_model
            mask_src
        """
        """ Embedding: batch_size x src_len x d_model """
        if self.input_type == "text":
            bsz_first_input = input
            input = input.transpose(0, 1)
            # mask_src = input.eq(onmt.constants.PAD).unsqueeze(0)  # batch_size x src_len x 1 for broadcasting

            dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

            if streaming:
                raise NotImplementedError
                streaming_state = kwargs.get('streaming_state', None)
                mems = streaming_state.src_mems
                # mem_len = streaming_state.src_mems[0].size(0)
                mem_len = streaming_state.prev_src_mem_size
                input_length = kwargs.get('src_lengths', None)
                streaming_state = kwargs.get('streaming_state', None)
                mask_src = self.create_stream_mask(input, input_length,
                                                   mem_len)
                mask_src = mask_src.unsqueeze(2)
            else:
                mem_len = 0
                mask_src = input.eq(onmt.constants.PAD).unsqueeze(
                    0)  # 1 x src_len x batch_size for broadcasting
                mems = None

            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)

            if self.double_position:
                assert input_pos is not None
                # flatten
                src_len, bsz = input_pos.size(0), input_pos.size(1)
                input_pos_ = input_pos.contiguous().view(-1).type_as(emb)
                abs_pos = self.positional_encoder(input_pos_)
                abs_pos = abs_pos.squeeze(1).view(src_len, bsz, -1)

            else:
                abs_pos = None
            """ Adding language embeddings """
            if self.use_language_embedding:
                assert self.language_embedding is not None
                # There is no "unsqueeze" here because the input is T x B x H and lang_emb is B x H
                if self.language_embedding_type in ['sum', 'all_sum']:
                    lang_emb = self.language_embedding(input_lang)
                    emb = emb + lang_emb.unsqueeze(1)

        else:
            if streaming:
                raise NotImplementedError

            if not self.cnn_downsampling:
                mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq(
                    onmt.constants.PAD).unsqueeze(0)
                dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq(
                    onmt.constants.PAD).unsqueeze(1)
                input = input.narrow(2, 1, input.size(2) - 1)
                emb = self.audio_trans(input.contiguous().view(
                    -1, input.size(2))).view(input.size(0), input.size(1), -1)
            else:
                long_mask = input.narrow(2, 0,
                                         1).squeeze(2).eq(onmt.constants.PAD)
                input = input.narrow(2, 1, input.size(2) - 1)

                # first resizing to fit the CNN format
                input = input.view(input.size(0), input.size(1), -1,
                                   self.channels)
                input = input.permute(0, 3, 1, 2)

                input = self.audio_trans(input)
                input = input.permute(0, 2, 1, 3).contiguous()
                input = input.view(input.size(0), input.size(1), -1)
                # print(input.size())
                input = self.linear_trans(input)

                mask_src = long_mask[:, 0:input.size(1) *
                                     4:4].transpose(0, 1).unsqueeze(0)
                dec_attn_mask = long_mask[:,
                                          0:input.size(1) * 4:4].unsqueeze(1)
                # the size seems to be B x T ?
                emb = input

            emb = emb.transpose(0, 1)
            input = input.transpose(0, 1)
            abs_pos = None
            mem_len = 0

        if onmt.constants.torch_version >= 1.2:
            mask_src = mask_src.bool()
        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)

        if self.double_position and abs_pos is not None:
            # adding position encoding
            emb = emb + abs_pos
        """ Adding positional encoding """
        qlen = input.size(0)
        klen = qlen + mem_len

        # Asynchronous positions: 2K+1 positions instead of K+1
        pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device).long()

        # because the batch dimension is lacking
        pos_emb = self.positional_encoder(pos).unsqueeze(1)

        # B x T x H -> T x B x H
        context = emb

        # Apply dropout to both context and pos_emb
        context = self.preprocess_layer(context)

        pos_emb = self.preprocess_layer(pos_emb)

        for i, layer in enumerate(self.layer_modules):
            # src_len x batch_size x d_model

            if streaming:
                buffer = streaming_state.src_buffer[i]
                context, buffer = layer(context,
                                        pos_emb,
                                        mask_src,
                                        incremental=True,
                                        incremental_cache=buffer)
                streaming_state.src_buffer[i] = buffer
            else:
                context = layer(context, pos_emb, mask_src)

        # last layer norm
        context = self.postprocess_layer(context)

        output_dict = defaultdict(lambda: None, {
            'context': context,
            'src_mask': dec_attn_mask,
            'src': input
        })

        if streaming:
            streaming_state.prev_src_mem_size += sum(input_length.tolist())
            streaming_state.prune_source_memory(self.max_memory_size)
            # streaming_state.update_src_mems(hids, qlen)
            output_dict['streaming_state'] = streaming_state

        return output_dict
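
The encoder above builds a symmetric range of relative offsets, klen - 1 down to -(klen - 1) (2*klen - 1 values), and runs it through a sinusoidal positional encoder before dropout. The standalone sketch below reproduces just that step; sinusoidal_embedding is a hypothetical stand-in for the positional_encoder used in the snippets, and d_model = 8 is an arbitrary choice.

import math
import torch

def sinusoidal_embedding(pos, d_model):
    # pos: 1-D tensor of (possibly negative) relative positions
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
    sinusoid = pos.float().unsqueeze(1) * inv_freq.unsqueeze(0)    # L x d_model/2
    return torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1)     # L x d_model

klen = 5
pos = torch.arange(klen - 1, -klen, -1.0).long()      # 2*klen - 1 offsets: 4, 3, ..., -4
pos_emb = sinusoidal_embedding(pos, 8).unsqueeze(1)   # (2*klen - 1) x 1 x d_model
print(pos.tolist(), pos_emb.shape)

The unsqueeze(1) mirrors the "batch dimension is lacking" step in the snippet above.
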


# class RelativeTransformerDecoder(TransformerDecoder):
#
#     def __init__(self, opt, dicts, positional_encoder, language_embeddings=None, ignore_source=False):
#
#         self.death_rate = opt.death_rate
#         self.double_position = opt.double_position
#         self.max_memory_size = opt.max_memory_size
#         self.stream_context = opt.stream_context
#         self.extra_context_size = opt.extra_context_size
#
#         # build_modules will be called from the inherited constructor
#         super(RelativeTransformerDecoder, self).__init__(opt, dicts,
#                                                          positional_encoder,
#                                                          language_embeddings,
#                                                          ignore_source,
#                                                          allocate_positions=False)
#         self.positional_encoder = SinusoidalPositionalEmbedding(opt.model_size)
#         self.d_head = self.model_size // self.n_heads
#         # Parameters for the position biases
#         self.r_w_bias = nn.Parameter(torch.Tensor(self.n_heads, self.d_head))
#         self.r_r_bias = nn.Parameter(torch.Tensor(self.n_heads, self.d_head))
#
#     def renew_buffer(self, new_len):
#         return
#
#     def build_modules(self):
#
#         e_length = expected_length(self.layers, self.death_rate)
#
#         print("* Transformer Decoder with Relative Attention with %.2f expected layers" % e_length)
#
#         self.layer_modules = nn.ModuleList()
#
#         for l in range(self.layers):
#             # linearly decay the death rate
#             death_r = (l + 1.0) / self.layers * self.death_rate
#
#             block = RelativeTransformerDecoderLayer(self.n_heads, self.model_size,
#                                                     self.dropout, self.inner_size, self.attn_dropout,
#                                                     variational=self.variational_dropout, death_rate=death_r)
#
#             self.layer_modules.append(block)
#
#     def process_embedding(self, input, input_lang=None):
#
#         return input
#
#     def create_context_mask(self, input, src, src_lengths, tgt_lengths, extra_context_length=0):
#         """
#         Generate the mask so that part of the target attends to a part of the source
#         :param extra_context_length:
#         :param input:
#         :param src:
#         :param src_lengths:
#         :param tgt_lengths:
#         :return:
#         """
#
#         mask = None
#
#         if self.stream_context == 'global':
#             # Global context: one target attends to everything in the source
#             for (src_length, tgt_length) in zip(src_lengths, tgt_lengths):
#
#                 if mask is None:
#                     prev_src_length = 0
#                     prev_tgt_length = 0
#                 else:
#                     prev_src_length, prev_tgt_length = mask.size(1), mask.size(0)
#
#                 # current sent attend to current src sent and all src in the past
#                 current_mask = input.new_zeros(tgt_length, src_length + prev_src_length)
#
#                 # the previous target cannot attend to the current source
#                 if prev_tgt_length > 0:
#                     prev_mask = input.new_ones(prev_tgt_length, src_length)
#                     prev_mask = torch.cat([mask, prev_mask], dim=-1)
#                 else:
#                     prev_mask = None
#
#                 # the output mask has two parts: the prev and the current
#                 if prev_mask is not None:
#                     mask = torch.cat([prev_mask, current_mask], dim=0)
#                 else:
#                     mask = current_mask
#
#         # elif self.stream_context == 'local_xl':
#         #     # Local extra context: only attends to the aligned context + extra mem
#         #     # This mode ensures that all target sentences have the same memory, not uneven like "global"
#         #
#         #     for (src_length, tgt_length) in zip(src_lengths, tgt_lengths):
#         #
#         #         # First: we read the existing mask to know where we are
#         #         if mask is None:
#         #             prev_src_length = 0
#         #             prev_tgt_length = 0
#         #         else:
#         #             prev_src_length, prev_tgt_length = mask.size(1), mask.size(0)
#         #
#         #             # current tgt sent attend to only current src sent
#         #             if prev_src_length > 0:
#         #                 current_mask = torch.cat([input.new_ones(tgt_length, prev_src_length - extra_context_length),
#         #                                           input.new_zeros(tgt_length, src_length + extra_context_length)], dim=-1)
#         #             else:
#         #                 current_mask = input.new_zeros(tgt_length, src_length + extra_context_length)
#         #
#         #                 # the previous target cannot attend to the current source
#         #                 if prev_tgt_length > 0:
#         #                     prev_mask = input.new_ones(prev_tgt_length, src_length)
#         #                     prev_mask = torch.cat([mask, prev_mask], dim=-1)
#         #                 else:
#         #                     prev_mask = None
#         #
#         #                 # the output mask has two parts: the prev and the current
#         #                 if prev_mask is not None:
#         #                     mask = torch.cat([prev_mask, current_mask], dim=0)
#         #                 else:
#         #                     mask = current_mask
#
#         elif self.stream_context in ['local', 'limited']:
#             # Local context: only attends to the aligned context
#             for (src_length, tgt_length) in zip(src_lengths, tgt_lengths):
#
#                 if mask is None:
#                     prev_src_length = 0
#                     prev_tgt_length = 0
#                 else:
#                     prev_src_length, prev_tgt_length = mask.size(1), mask.size(0)
#
#                 # current tgt sent attend to only current src sent
#                 if prev_src_length > 0:
#                     current_mask = torch.cat([input.new_ones(tgt_length, prev_src_length - extra_context_length),
#                                               input.new_zeros(tgt_length, src_length + extra_context_length)], dim=-1)
#                 else:
#                     current_mask = input.new_zeros(tgt_length, src_length + extra_context_length)
#
#                 # the previous target cannot attend to the current source
#                 if prev_tgt_length > 0:
#                     prev_mask = input.new_ones(prev_tgt_length, src_length)
#                     prev_mask = torch.cat([mask, prev_mask], dim=-1)
#                 else:
#                     prev_mask = None
#
#                 # the output mask has two parts: the prev and the current
#                 if prev_mask is not None:
#                     mask = torch.cat([prev_mask, current_mask], dim=0)
#                 else:
#                     mask = current_mask
#
#         mask = mask.bool()
#         return mask
#
#     def create_self_attn_mask(self, input, tgt_lengths, prev_tgt_mem_size):
#         """
#         Create a mask for the target words attending to the past
#         :param input:
#         :param tgt_lengths:
#         :param prev_tgt_mem_size:
#         :return:
#         """
#
#         if self.stream_context in ['local', 'global']:
#             qlen = sum(tgt_lengths.tolist())
#             mlen = prev_tgt_mem_size
#             klen = qlen + mlen
#             mask = torch.triu(input.new_ones(qlen, klen), diagonal=1 + mlen).bool()[:, :, None]
#         elif self.stream_context in ['limited']:
#
#             # past_length = prev_tgt_mem_size
#             mask = None
#             # assert prev_tgt_mem_size == 0, "This model is limited and doesn't accept memory"
#
#             for length in tgt_lengths:
#
#                 past_length = mask.size(0) if mask is not None else 0
#
#                 if past_length > 0:
#                     # don't look at the past
#                     past_mask = input.new_ones(length, past_length)
#                 else:
#                     past_mask = None
#
#                 # pay attention to the past words in the current sentence
#                 current_mask = torch.triu(input.new_ones(length, length), diagonal=1)
#
#                 if past_mask is not None:
#                     current_mask = torch.cat([past_mask, current_mask], dim=1)
#
#                 if mask is None:
#                     mask = current_mask
#                 else:
#                     no_future_mask = input.new_ones(past_length, length)
#                     mask = torch.cat([mask, no_future_mask], dim=1)
#                     mask = torch.cat([mask, current_mask], dim=0)
#
#             mask = mask.bool().unsqueeze(-1)
#
#         return mask
#
#     # TODO: merging forward_stream and forward
#     # TODO: write a step function for encoder
#
#     def forward(self, input, context, src, input_pos=None, input_lang=None, streaming=False, **kwargs):
#         """
#                 Inputs Shapes:
#                     input: (Variable) batch_size x len_tgt (to be transposed)
#                     context: (Variable) batch_size x src_len x d_model
#                     mask_src (Tensor) batch_size x src_len
#                 Outputs Shapes:
#                     out: batch_size x len_tgt x d_model
#                     coverage: batch_size x len_tgt x src_len
#
#                 """
#
#         """ Embedding: batch_size x len_tgt x d_model """
#         input = input.transpose(0, 1)  # T x B
#         emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
#         emb = emb * math.sqrt(self.model_size)
#
#         if streaming:
#             src_lengths = kwargs.get("src_lengths", None)
#             tgt_lengths = kwargs.get("tgt_lengths", None)
#             streaming_state = kwargs.get("streaming_state")
#             # mems = streaming_state.tgt_mems
#             mem_len = streaming_state.prev_tgt_mem_size
#             extra_context = streaming_state.extra_context
#             extra_context_length = extra_context.size(0) if extra_context is not None else 0
#             # mem_len = mems[0].size(0) if mems is not None else 0
#         else:
#             mem_len = 0
#             mems = None
#             extra_context = None
#
#         if self.double_position:
#             assert input_pos is not None
#             tgt_len, bsz = input_pos.size(0), input_pos.size(1)
#             input_pos_ = input_pos.view(-1).type_as(emb)
#             abs_pos = self.positional_encoder(input_pos_).squeeze(1).view(tgt_len, bsz, -1)
#
#             emb = emb + abs_pos
#
#         if self.use_language_embedding:
#             lang_emb = self.language_embeddings(input_lang)  # B x H or 1 x H
#             if self.language_embedding_type == 'sum':
#                 emb = emb + lang_emb
#             elif self.language_embedding_type == 'concat':
#                 # replace the bos embedding with the language
#                 bos_emb = lang_emb.expand_as(emb[0])
#                 emb[0] = bos_emb
#
#                 lang_emb = lang_emb.unsqueeze(0).expand_as(emb)
#                 concat_emb = torch.cat([emb, lang_emb], dim=-1)
#                 emb = torch.relu(self.projector(concat_emb))
#             else:
#                 raise NotImplementedError
#
#         if context is not None:
#             if self.encoder_type == "audio":
#                 if not self.encoder_cnn_downsampling:
#                     mask_src = src.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
#                 else:
#                     long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
#                     mask_src = long_mask[:, 0:context.size(0) * 4:4].unsqueeze(1)
#             else:
#                 if streaming:
#                     context_attn_mask = self.create_context_mask(input, src,
#                                                                  src_lengths, tgt_lengths,
#                                                                  extra_context_length)
#                     mask_src = context_attn_mask.unsqueeze(0)
#                 else:
#                     mask_src = src.eq(onmt.constants.PAD).unsqueeze(1)
#         else:
#             mask_src = None
#
#         qlen = input.size(0)
#         klen = qlen + mem_len
#         # preparing self-attention mask. The input is either left or right aligned
#
#         if streaming:
#             dec_attn_mask = self.create_self_attn_mask(input, tgt_lengths, mem_len)
#         else:
#             dec_attn_mask = torch.triu(
#                 emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None]
#             pad_mask = input.eq(onmt.constants.PAD).byte()  # L x B
#
#             dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0)
#             dec_attn_mask = dec_attn_mask.gt(0)
#             if onmt.constants.torch_version >= 1.2:
#                 dec_attn_mask = dec_attn_mask.bool()
#
#         pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)
#
#         pos_emb = self.positional_encoder(pos)
#
#         output = self.preprocess_layer(emb.contiguous())
#
#         if streaming:
#             hids = [output]
#             if extra_context is not None:
#                 context = torch.cat([extra_context, context], dim=0)
#                 # print(context.size(), context_attn_mask.size())
#
#         pos_emb = self.preprocess_layer(pos_emb)
#
#         for i, layer in enumerate(self.layer_modules):
#             # batch_size x src_len x d_model output, coverage = layer(output, context, pos_emb, self.r_w_bias,
#             # self.r_r_bias, dec_attn_mask, mask_src)
#             # mems_i = mems[i] if mems is not None and streaming and
#             # self.stream_context in ['local', 'global'] else None
#             if streaming:
#                 buffer = streaming_state.tgt_buffer[i]
#                 output, coverage, buffer = layer(output, context, pos_emb, dec_attn_mask, context_attn_mask,
#                                                  incremental=True, incremental_cache=buffer, reuse_source=False)
#                 streaming_state.tgt_buffer[i] = buffer
#             else:
#                 output, coverage, _ = layer(output, context, pos_emb, dec_attn_mask, mask_src   )
#                 # if streaming:
#                 #     hids.append(output)
#
#         # From Google T2T
#         # if normalization is done in layer_preprocess, then it should also be done
#         # on the output, since the output can grow very large, being the sum of
#         # a whole stack of unnormalized layer outputs.
#         output = self.postprocess_layer(output)
#
#         output_dict = {'hidden': output, 'coverage': coverage, 'context': context}
#         output_dict = defaultdict(lambda: None, output_dict)
#
#         if streaming:
#             streaming_state.prev_tgt_mem_size += sum(tgt_lengths.tolist())
#             streaming_state.prune_target_memory(self.max_memory_size)
#
#             # if we use the extra context: keep the last context
#             if self.extra_context_size > 0:
#                 extra_context = context[-self.extra_context_size:].detach()
#                 streaming_state.extra_context = extra_context
#
#             # if self.stream_context in ['local', 'global']:
#             #     streaming_state.update_tgt_mems(hids, qlen)
#             output_dict['streaming_state'] = streaming_state
#
#         return output_dict
#
#     def step(self, input, decoder_state, streaming=False):
#         """
#         Inputs Shapes:
#             input: (Variable) batch_size x len_tgt (to be transposed)
#             context: (Variable) batch_size x src_len x d_model
#             mask_src (Tensor) batch_size x src_len
#             buffer (List of tensors) List of batch_size * len_tgt-1 * d_model for self-attention recomputing
#         Outputs Shapes:
#             out: batch_size x len_tgt x d_model
#             coverage: batch_size x len_tgt x src_len
#
#         """
#
#         if streaming:
#             return self.step_streaming(input, decoder_state)
#
#         context = decoder_state.context
#         buffers = decoder_state.attention_buffers
#         lang = decoder_state.tgt_lang
#         mask_src = decoder_state.src_mask
#
#         if decoder_state.concat_input_seq:
#             if decoder_state.input_seq is None:
#                 decoder_state.input_seq = input
#             else:
#                 # concatenate the last input to the previous input sequence
#                 decoder_state.input_seq = torch.cat([decoder_state.input_seq, input], 0)
#             input = decoder_state.input_seq.transpose(0, 1)  # B x T
#
#         src = decoder_state.src.transpose(0, 1) if decoder_state.src is not None else None
#
#         # use the last value of input to continue decoding
#         if input.size(1) > 1:
#             input_ = input[:, -1].unsqueeze(1).transpose(0, 1)
#         else:
#             input_ = input.transpose(0, 1)
#
#         """ Embedding: batch_size x 1 x d_model """
#         emb = self.word_lut(input_) * math.sqrt(self.model_size)
#         input = input.transpose(0, 1)
#         klen = input.size(0)
#         # emb = self.word_lut(input) * math.sqrt(self.model_size)
#
#         if self.double_position:
#             input_pos = torch.arange(input.size(0), dtype=emb.dtype, device=emb.device)
#             input_pos = input_pos.unsqueeze(1).repeat(1, input.size(1))
#             tgt_len, bsz = input_pos.size(0), input_pos.size(1)
#             input_pos_ = input_pos.view(-1).type_as(emb)
#             abs_pos = self.positional_encoder(input_pos_).squeeze(1).view(tgt_len, bsz, -1)
#             emb = emb + abs_pos[-1:, :, :]
#
#         if self.use_language_embedding:
#             lang_emb = self.language_embeddings(lang)  # B x H
#
#             if self.language_embedding_type in ['sum', 'all_sum']:
#                 emb = emb + lang_emb
#             elif self.language_embedding_type == 'concat':
#                 if input.size(0) == 1:
#                     emb[0] = lang_emb
#
#                 lang_emb = lang_emb.unsqueeze(0).expand_as(emb)
#                 concat_emb = torch.cat([emb, lang_emb], dim=-1)
#                 emb = torch.relu(self.projector(concat_emb))
#             else:
#                 raise NotImplementedError
#
#         # prepare position encoding
#         qlen = emb.size(0)
#         mlen = klen - qlen
#
#         pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)
#
#         pos_emb = self.positional_encoder(pos)
#
#         dec_attn_mask = torch.triu(
#             emb.new_ones(qlen, klen), diagonal=1 + mlen).byte()[:, :, None]
#
#         pad_mask = input.eq(onmt.constants.PAD).byte()  # L x B
#
#         dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0)
#         dec_attn_mask = dec_attn_mask.gt(0)
#
#         if onmt.constants.torch_version >= 1.2:
#             dec_attn_mask = dec_attn_mask.bool()
#
#         if context is not None:
#             if self.encoder_type == "audio":
#                 if not self.encoder_cnn_downsampling:
#                     mask_src = src.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
#                 else:
#                     long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
#                     mask_src = long_mask[:, 0:context.size(0) * 4:4].unsqueeze(1)
#             else:
#
#                 mask_src = src.eq(onmt.constants.PAD).unsqueeze(1)
#         else:
#             mask_src = None
#
#         output = emb.contiguous()
#
#         for i, layer in enumerate(self.layer_modules):
#             buffer = buffers[i] if i in buffers else None
#             # assert (output.size(0) == 1)
#
#             # output, coverage, buffer = layer.step(output, context, pos_emb,
#             #                                       dec_attn_mask, mask_src, buffer=buffer)
#             output, coverage, buffer = layer(output, context, pos_emb, dec_attn_mask, mask_src,
#                                              incremental=True, incremental_cache=buffer)
#
#             decoder_state.update_attention_buffer(buffer, i)
#
#         output = self.postprocess_layer(output)
#         output = output[-1].unsqueeze(0)
#
#         output_dict = defaultdict(lambda: None)
#         output_dict['hidden'] = output
#         output_dict['coverage'] = coverage
#         output_dict['context'] = context
#
#         return output_dict
#
#     def step_streaming(self, input, decoder_state):
#         """Step function in streaming case"""
#
#         context = decoder_state.context
#         lang = decoder_state.tgt_lang
#         streaming_state = decoder_state.streaming_state
#
#         # for global model: push the context in
#
#         if decoder_state.concat_input_seq:
#             if decoder_state.input_seq is None:
#                 decoder_state.input_seq = input
#             else:
#                 # concatenate the last input to the previous input sequence
#                 decoder_state.input_seq = torch.cat([decoder_state.input_seq, input], 0)
#             input = decoder_state.input_seq.transpose(0, 1)  # B x T
#
#         src = decoder_state.src.transpose(0, 1) if decoder_state.src is not None else None
#
#         # use the last value of input to continue decoding
#         if input.size(1) > 1:
#             input_ = input[:, -1].unsqueeze(1).transpose(0, 1)
#         else:
#             input_ = input.transpose(0, 1)
#
#         emb = self.word_lut(input_) * math.sqrt(self.model_size)
#         input = input.transpose(0, 1)  # B x T to T x B
#         klen = input.size(0)
#
#         # If we start a new sentence to decode: reset the context memory
#         if klen == 1:
#             streaming_state.reset_context_memory()
#             if self.stream_context == 'limited':
#                 streaming_state.reset_target_memory()
#
#         if self.use_language_embedding:
#             lang_emb = self.language_embeddings(lang)  # B x H or 1 x H
#             if self.language_embedding_type == 'sum':
#                 emb = emb + lang_emb
#             elif self.language_embedding_type == 'concat':
#                 # replace the bos embedding with the language
#                 bos_emb = lang_emb.expand_as(emb[0])
#                 emb[0] = bos_emb
#
#                 lang_emb = lang_emb.unsqueeze(0).expand_as(emb)
#                 concat_emb = torch.cat([emb, lang_emb], dim=-1)
#                 emb = torch.relu(self.projector(concat_emb))
#             else:
#                 raise NotImplementedError
#
#         # need to manually define src_lengths and tgt_lengths here
#         src_lengths = torch.LongTensor([context.size(0)])
#         tgt_lengths = torch.LongTensor([1])
#
#         if context is not None:
#             context_attn_mask = self.create_context_mask(input, src, src_lengths, tgt_lengths)
#             context_attn_mask = context_attn_mask.unsqueeze(0)
#         else:
#             context_attn_mask = None
#
#         dec_attn_mask = self.create_self_attn_mask(input, tgt_lengths, streaming_state.prev_tgt_mem_size)
#
#         dec_attn_mask = dec_attn_mask[:, -1:, :]
#
#         klen = 1 + streaming_state.prev_tgt_mem_size
#         pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)
#
#         pos_emb = self.positional_encoder(pos)
#
#         output = emb
#
#         for i, layer in enumerate(self.layer_modules):
#             # T x B x d_model
#             buffer = streaming_state.tgt_buffer[i]
#             # output, coverage = layer(output, context, pos_emb, self.r_w_bias, self.r_r_bias, dec_attn_mask, mask_src)
#             # reuse_source = True if input.size(1) == 1 else False
#             reuse_source = True
#
#             # reuse source is True in this case because we can reuse the context ...
#             output, coverage, buffer = layer(output, context, pos_emb, dec_attn_mask, context_attn_mask,
#                                              incremental=True, incremental_cache=buffer, reuse_source=reuse_source)
#             streaming_state.tgt_buffer[i] = buffer
#
#         output = self.postprocess_layer(output)
#
#         streaming_state.prev_tgt_mem_size += 1
#         streaming_state.prune_target_memory(self.max_memory_size + input.size(0))
#
#         extra_context = context[-self.extra_context_size:].detach()
#
#         output_dict = defaultdict(lambda: None, {'hidden': output, 'coverage': coverage, 'context': context})
#         output_dict['streaming_state'] = streaming_state
#
#         return output_dict
Example #16
    def forward(self, input, **kwargs):
        """
        Inputs Shapes:
            input: batch_size x len_src

        Outputs Shapes:
            out: batch_size x len_src x d_model
            mask_src

        """
        # clean layer history
        self.history.clean()

        # Embedding: batch_size x len_src x d_model
        if self.input_type == "text":
            mask_src = input.data.eq(onmt.constants.PAD).unsqueeze(
                1)  # batch_size x 1 x len_src for broadcasting
            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)
        else:
            mask_src = input.narrow(2, 0, 1).squeeze(2).eq(
                onmt.constants.PAD).unsqueeze(1)
            input = input.narrow(2, 1, input.size(2) - 1)
            emb = self.audio_trans(input.contiguous().view(
                -1, input.size(2))).view(input.size(0), input.size(1), -1)

        # Scale the emb by sqrt(d_model)

        emb = emb * math.sqrt(self.model_size)

        # Adding positional encoding
        emb = self.time_transformer(emb)
        # Dropout
        emb = self.preprocess_layer(emb)

        # B x T x H -> T x B x H
        context = emb.transpose(0, 1).contiguous()

        self.history.push(context)

        for i, layer in enumerate(self.layer_modules):

            context = self.history.pop()

            if len(self.layer_modules
                   ) - i <= onmt.constants.checkpointing and self.training:
                context = checkpoint(custom_layer(layer), context, mask_src)

            else:
                context = layer(context,
                                mask_src)  # batch_size x len_src x d_model

            self.history.push(context)

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        context = self.history.pop()
        context = self.postprocess_layer(context)

        output_dict = {'context': context, 'src_mask': mask_src}

        # return context, mask_src
        return output_dict
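
Example #16 differs from the plain encoders mainly through self.history, which is pushed after every layer and popped before the next one so that later layers can consume a (possibly recombined) view of earlier outputs. Below is a toy sketch of that push/pop contract; SimpleLayerHistory is a made-up stand-in that simply returns the latest state, whereas the real history module may return a weighted combination of all stored states.

import torch
import torch.nn as nn

class SimpleLayerHistory:
    """Toy history: stores layer outputs and returns the most recent one on pop."""
    def __init__(self):
        self.states = []

    def clean(self):
        self.states = []

    def push(self, state):
        self.states.append(state)

    def pop(self):
        # a real implementation might return a learned combination of all states
        return self.states[-1]

history = SimpleLayerHistory()
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

x = torch.randn(4, 8)
history.clean()
history.push(x)
for layer in layers:
    x = history.pop()
    x = torch.relu(layer(x))
    history.push(x)
out = history.pop()
print(out.shape)
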
Example #17
    def forward(self, input, context, src, atbs=None, **kwargs):
        """
        Inputs Shapes:
            input: (Variable) batch_size x len_tgt (to be transposed)
            context: (Variable) batch_size x len_src x d_model
            mask_src (Tensor) batch_size x len_src
        Outputs Shapes:
            out: batch_size x len_tgt x d_model
            coverage: batch_size x len_tgt x len_src

        """
        """ Embedding: batch_size x len_tgt x d_model """
        self.history.clean()

        emb = embedded_dropout(
            self.word_lut,
            input,
            dropout=self.word_dropout if self.training else 0)
        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)
        if isinstance(emb, tuple):
            emb = emb[0]
        emb = self.preprocess_layer(emb)

        if self.use_feature:
            atb_emb = self.attribute_embeddings(atbs).unsqueeze(1).repeat(
                1, emb.size(1), 1)  # B x H -> B x len_tgt x H
            emb = torch.cat([emb, atb_emb], dim=-1)
            emb = torch.relu(self.feature_projector(emb))

        if context is not None:
            if self.encoder_type == "audio":
                mask_src = src.data.narrow(2, 0, 1).squeeze(2).eq(
                    onmt.constants.PAD).unsqueeze(1)
            else:

                mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1)
        else:
            mask_src = None

        len_tgt = input.size(1)
        mask_tgt = input.data.eq(
            onmt.constants.PAD).unsqueeze(1) + self.mask[:len_tgt, :len_tgt]
        mask_tgt = torch.gt(mask_tgt, 0)

        output = emb.transpose(0, 1).contiguous()

        self.history.push(output)

        for i, layer in enumerate(self.layer_modules):

            output = self.history.pop()

            if len(self.layer_modules
                   ) - i <= onmt.constants.checkpointing and self.training:

                output, coverage = checkpoint(custom_layer(layer), output,
                                              context, mask_tgt, mask_src)
                # batch_size x len_src x d_model

            else:
                output, coverage = layer(
                    output, context, mask_tgt,
                    mask_src)  # batch_size x len_src x d_model

            # write into memory
            self.history.push(output)

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        output = self.history.pop()
        output = self.postprocess_layer(output)

        output_dict = {'hidden': output, 'coverage': coverage}

        # return output, None
        return output_dict
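
The decoder above forms its target mask by adding a padding mask (batch_size x 1 x len_tgt) to a causal upper-triangular mask and thresholding with torch.gt. A minimal sketch of that combination, assuming PAD = 0 (the actual value comes from onmt.constants.PAD):

import torch

PAD = 0  # assumed padding index for this sketch
tgt = torch.tensor([[5, 6, 7, PAD],
                    [8, 9, PAD, PAD]])                    # batch_size x len_tgt

len_tgt = tgt.size(1)
causal = torch.triu(torch.ones(len_tgt, len_tgt, dtype=torch.uint8), diagonal=1)

pad = tgt.eq(PAD).unsqueeze(1).to(torch.uint8)            # batch_size x 1 x len_tgt
mask_tgt = torch.gt(pad + causal, 0)                      # batch_size x len_tgt x len_tgt
print(mask_tgt[0].int())                                  # True where attention is blocked
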
Example #18
    def forward(self, input, input_pos=None, **kwargs):
        """
        Inputs Shapes:
            input: batch_size x src_len (to be transposed)
        Outputs Shapes:
            out: batch_size x src_len x d_model
            mask_src
        """
        """ Embedding: batch_size x src_len x d_model """
        if self.input_type == "text":
            bsz_first_input = input
            input = input.transpose(0, 1)
            # mask_src = input.eq(onmt.constants.PAD).unsqueeze(1)  # batch_size x src_len x 1 for broadcasting
            mask_src = input.eq(onmt.constants.PAD).unsqueeze(0)
            dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)

            # if self.double_position:
            #     assert input_pos is not None
            #     # flatten
            #     src_len, bsz = input_pos.size(0), input_pos.size(1)
            #     input_pos_ = input_pos.contiguous().view(-1).type_as(emb)
            #     abs_pos = self.positional_encoder(input_pos_)
            #     abs_pos = abs_pos.squeeze(1).view(src_len, bsz, -1)
            #
            # else:
            #     abs_pos = None
        else:
            if not self.cnn_downsampling:
                mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq(
                    onmt.constants.PAD).unsqueeze(0)
                dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq(
                    onmt.constants.PAD).unsqueeze(1)
                input = input.narrow(2, 1, input.size(2) - 1)
                emb = self.audio_trans(input.contiguous().view(
                    -1, input.size(2))).view(input.size(0), input.size(1), -1)
            else:
                long_mask = input.narrow(2, 0,
                                         1).squeeze(2).eq(onmt.constants.PAD)
                input = input.narrow(2, 1, input.size(2) - 1)

                # first resizing to fit the CNN format
                input = input.view(input.size(0), input.size(1), -1,
                                   self.channels)
                input = input.permute(0, 3, 1, 2)

                input = self.audio_trans(input)
                input = input.permute(0, 2, 1, 3).contiguous()
                input = input.view(input.size(0), input.size(1), -1)
                # print(input.size())
                input = self.linear_trans(input)

                mask_src = long_mask[:, 0:input.size(1) * 4:4].transpose(0, 1).unsqueeze(0)
                dec_attn_mask = long_mask[:,
                                          0:input.size(1) * 4:4].unsqueeze(1)
                # input is now batch_size x downsampled_len x d_model (after linear_trans)
                emb = input

            emb = emb.transpose(0, 1)
            input = input.transpose(0, 1)
            abs_pos = None

        if onmt.constants.torch_version >= 1.2:
            mask_src = mask_src.bool()

        # Scale the emb by sqrt(d_model)
        emb = emb * math.sqrt(self.model_size)

        # if self.double_position and abs_pos is not None:
        #     # adding position encoding
        #     emb = emb + abs_pos
        klen = input.size(0)

        # allocate relative positions: from klen - 1 down to -(klen - 1)
        pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device)

        # clamp the positions so that positions beyond max_pos_length share the same embedding
        pos = torch.clamp(pos, -self.max_pos_length, self.max_pos_length)

        # L x 1 x H
        pos_emb = self.positional_encoder(pos.unsqueeze(1))

        # Apply dropout to both context and pos_emb
        context = self.preprocess_layer(emb)

        pos_emb = self.preprocess_layer(pos_emb)

        for i, layer in enumerate(self.layer_modules):
            # src_len x batch_size x d_model
            context = layer(context, pos_emb, mask_src)

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        context = self.postprocess_layer(context)

        output_dict = {
            'context': context,
            'src_mask': dec_attn_mask,
            'src': input
        }

        # return context, mask_src
        return output_dict
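
In the audio branch above, the first feature column carries the padding flag, the remaining features go through a convolutional front-end that shrinks the time axis roughly by a factor of 4, and the frame-level padding mask is therefore subsampled with stride 4 (long_mask[:, 0:T*4:4]). The sketch below replays that bookkeeping on dummy features; the downsampling factor of 4 and PAD = 0 are assumptions taken from the stride used in the snippet.

import torch

PAD = 0                      # assumed padding value in the first feature column
B, T = 2, 12                 # batch of 12 audio frames
feats = torch.randn(B, T, 41)
feats[1, 8:, 0] = PAD        # mark the last 4 frames of sample 1 as padding
long_mask = feats.narrow(2, 0, 1).squeeze(2).eq(PAD)   # B x T frame-level pad mask

downsample = 4               # assumed total time-stride of the CNN front-end
T_out = T // downsample      # length after the conv stack

# keep one mask entry per output frame: every 4th input frame
mask_src = long_mask[:, 0:T_out * downsample:downsample].unsqueeze(1)  # B x 1 x T_out
print(mask_src.shape, mask_src[1])
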
Example #19
    def forward(self, input, input_pos=None, input_lang=None, **kwargs):
        """
        Inputs Shapes:
            input: batch_size x src_len (to be transposed)
        Outputs Shapes:
            out: batch_size x src_len x d_model
            mask_src
        """
        """ Embedding: batch_size x src_len x d_model """
        if self.input_type == "text":
            bsz_first_input = input
            input = input.transpose(0, 1)
            # mask_src = input.eq(onmt.constants.PAD).unsqueeze(1)  # batch_size x src_len x 1 for broadcasting
            dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

            mem_len = 0
            mask_src = input.eq(onmt.constants.PAD).unsqueeze(
                0)  # 1 x src_len x batch_size for broadcasting
            mems = None

            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)
            """ Adding language embeddings """
            if self.use_language_embedding:
                assert self.language_embedding is not None
                # lang_emb is B x H (or 1 x H); unsqueeze at dim 0 so it broadcasts over the T x B x H embeddings
                if self.language_embedding_type in ['sum', 'all_sum']:
                    lang_emb = self.language_embedding(input_lang)
                    emb = emb + lang_emb.unsqueeze(0)

        else:
            if not self.cnn_downsampling:
                mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq(
                    onmt.constants.PAD).unsqueeze(0)
                dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq(
                    onmt.constants.PAD).unsqueeze(1)
                input = input.narrow(2, 1, input.size(2) - 1)
                emb = self.audio_trans(input.contiguous().view(
                    -1, input.size(2))).view(input.size(0), input.size(1), -1)
                emb = emb.type_as(input)
            else:
                long_mask = input.narrow(2, 0,
                                         1).squeeze(2).eq(onmt.constants.PAD)
                input = input.narrow(2, 1, input.size(2) - 1)

                # first resizing to fit the CNN format
                input = input.view(input.size(0), input.size(1), -1,
                                   self.channels)
                input = input.permute(0, 3, 1, 2)

                input = self.audio_trans(input)
                input = input.permute(0, 2, 1, 3).contiguous()
                input = input.view(input.size(0), input.size(1), -1)
                # print(input.size())
                input = self.linear_trans(input)

                mask_src = long_mask[:, 0:input.size(1) * 4:4].transpose(0, 1).unsqueeze(0)
                dec_attn_mask = long_mask[:,
                                          0:input.size(1) * 4:4].unsqueeze(1)
                # input is now batch_size x downsampled_len x d_model (after linear_trans)
                emb = input

            emb = emb.transpose(0, 1)
            input = input.transpose(0, 1)
            abs_pos = None
            mem_len = 0
            mems = None

        if self.unidirectional:
            qlen = input.size(0)
            klen = qlen + mem_len
            attn_mask_src = torch.triu(emb.new_ones(qlen, klen),
                                       diagonal=1 + mem_len).byte()[:, :, None]

            # pad_mask = mask_src

            # mask_src = pad_mask + attn_mask_src
            # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0)
            # mask_src = mask_src.gt(0)

            # with right padding, causal mask covers the mask pad
            mask_src = attn_mask_src

        if onmt.constants.torch_version >= 1.2:
            mask_src = mask_src.bool()
        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)
        """ positional encoding """
        qlen = input.size(0)
        klen = qlen + mem_len

        # Asynchronous positions: 2K+1 positions instead of K+1
        if self.unidirectional:
            pos = torch.arange(klen - 1,
                               -1,
                               -1.0,
                               device=emb.device,
                               dtype=emb.dtype)
        else:
            pos = torch.arange(klen - 1,
                               -klen,
                               -1.0,
                               device=emb.device,
                               dtype=emb.dtype)

        # pos_emb has size 2T+1 x 1 x H
        pos_emb = self.positional_encoder(
            pos, bsz=input.size(1) if self.fast_self_attn else None)

        # emb is already T x B x H at this point
        context = emb

        # Apply dropout to both context and pos_emb
        context = self.preprocess_layer(context)

        pos_emb = self.preprocess_layer(pos_emb)

        for i, layer in enumerate(self.layer_modules):
            # src_len x batch_size x d_model

            context = layer(context, pos_emb, mask_src)

        # final layer norm
        context = self.postprocess_layer(context)

        output_dict = defaultdict(lambda: None, {
            'context': context,
            'src_mask': dec_attn_mask,
            'src': input
        })

        return output_dict
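
Example #19 switches between a causal set of relative offsets (klen - 1 down to 0) when unidirectional and a symmetric set (klen - 1 down to -(klen - 1)) otherwise, and in the unidirectional case relies on an upper-triangular mask rather than the padding mask. A quick sketch of both offset ranges and the causal mask, with klen = 4 chosen arbitrarily:

import torch

klen = 4

uni = torch.arange(klen - 1, -1, -1.0)      # klen offsets: 3, 2, 1, 0
bi = torch.arange(klen - 1, -klen, -1.0)    # 2*klen - 1 offsets: 3, 2, ..., -3
print(uni.tolist(), bi.tolist())

# causal mask used in the unidirectional case: qlen x klen x 1, True above the diagonal
attn_mask_src = torch.triu(torch.ones(klen, klen), diagonal=1).bool()[:, :, None]
print(attn_mask_src.squeeze(-1).int())
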
Example #20
    def forward(self,
                input,
                context,
                src,
                input_pos=None,
                src_lang=None,
                tgt_lang=None,
                streaming=False,
                **kwargs):
        """
                Inputs Shapes:
                    input: (Variable) batch_size x len_tgt (wanna tranpose)
                    context: (Variable) batch_size x src_len x d_model
                    mask_src (Tensor) batch_size x src_len
                Outputs Shapes:
                    out: batch_size x len_tgt x d_model
                    coverage: batch_size x len_tgt x src_len

                """
        """ Embedding: batch_size x len_tgt x d_model """
        input = input.transpose(0, 1)  # T x B
        emb = embedded_dropout(
            self.word_lut,
            input,
            dropout=self.word_dropout if self.training else 0)
        if not self.late_emb_scale:
            emb = emb * math.sqrt(self.model_size)

        mem_len = 0
        mems = None
        extra_context = None

        if self.use_language_embedding:
            lang_emb = self.language_embeddings(tgt_lang)  # B x H or 1 x H
            if self.language_embedding_type == 'sum':
                emb = emb + lang_emb
            elif self.language_embedding_type == 'concat':
                lang_emb = lang_emb.unsqueeze(0).expand_as(emb)
                concat_emb = torch.cat([emb, lang_emb], dim=-1)
                emb = torch.relu(self.projector(concat_emb))
            else:
                raise NotImplementedError

        if context is not None:
            mask_src = src.eq(onmt.constants.PAD).unsqueeze(1)
        else:
            mask_src = None

        qlen = input.size(0)
        klen = qlen + mem_len

        # preparing self-attention mask. The input is left aligned so we do not need to add the pad mask

        dec_attn_mask = torch.triu(emb.new_ones(qlen, klen),
                                   diagonal=1 + mem_len).byte()[:, :, None]
        # pad_mask = input.eq(onmt.constants.PAD).byte()  # L x B
        #
        # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0)
        # dec_attn_mask = dec_attn_mask.gt(0)
        dec_attn_mask = dec_attn_mask.bool()

        if not self.absolute_position_encoding:
            # relative positions
            if not self.learnable_position_encoding:
                pos = torch.arange(klen - 1,
                                   -1,
                                   -1.0,
                                   device=emb.device,
                                   dtype=emb.dtype)
                pos_emb = self.positional_encoder(pos, bsz=input.size(1))
                pos_emb = self.preprocess_layer(pos_emb)
            else:
                range_vec = torch.arange(klen, device=emb.device)
                range_mat = range_vec.unsqueeze(-1).expand(-1, klen).transpose(
                    0, 1)
                distance_mat = range_vec - range_mat.transpose(0, 1)
                distance_mat.clamp_(-self.max_pos_length,
                                    self.max_pos_length).add_(
                                        self.max_pos_length)
                pos_emb = distance_mat
                # pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype).long()
                # pos.clamp_(-self.max_pos_length, self.max_pos_length).add_(self.max_pos_length)
                # pos_emb = pos.unsqueeze(1)
        else:
            # absolute positions
            emb = self.positional_encoder(emb.transpose(0, 1)).transpose(0, 1)
            pos, pos_emb = None, None
            dec_attn_mask = dec_attn_mask.squeeze(-1)

        if self.late_emb_scale:
            emb = emb * math.sqrt(self.model_size)

        output = self.preprocess_layer(emb.contiguous())

        if self.reversible:
            # TODO: add src lang and tgt lang to reversible
            output, coverage = reversible_decoder(
                self.layer_modules, output, pos_emb, context,
                dec_attn_mask.squeeze(-1), mask_src, False,
                None)  # incremental variables
        else:
            for i, layer in enumerate(self.layer_modules):
                output, coverage = layer(output,
                                         context,
                                         pos_emb,
                                         dec_attn_mask,
                                         mask_src,
                                         src_lang=src_lang,
                                         tgt_lang=tgt_lang)
                # if self.checkpointing == 0 or self.training is False:
                #
                #     output, coverage = layer(output, context, pos_emb, dec_attn_mask, mask_src,
                #                                 src_lang=src_lang, tgt_lang=tgt_lang)
                #
                # else:
                #     output, coverage = checkpoint(create_forward_function(layer), output, context, pos_emb,
                #                                      dec_attn_mask,
                #                                      mask_src, src_lang, tgt_lang)

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        output = self.postprocess_layer(output)

        output_dict = {
            'hidden': output,
            'coverage': coverage,
            'context': context
        }
        output_dict = defaultdict(lambda: None, output_dict)

        return output_dict
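
With learnable_position_encoding, Example #20 skips the sinusoids entirely and instead builds a matrix of clamped pairwise distances, presumably used downstream to index a relative-position embedding table of size 2 * max_pos_length + 1. The sketch below shows an equivalent, slightly more direct construction of that index matrix plus a hypothetical embedding lookup; max_pos_length = 3 and the table itself are assumptions.

import torch
import torch.nn as nn

klen = 6
max_pos_length = 3                                        # assumed clipping radius
rel_embedding = nn.Embedding(2 * max_pos_length + 1, 8)   # hypothetical table

range_vec = torch.arange(klen)
# distance_mat[i, j] = j - i (how far key j is from query i)
distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
# clip long-range distances and shift into [0, 2*max_pos_length]
index_mat = distance_mat.clamp(-max_pos_length, max_pos_length) + max_pos_length

pos_emb = rel_embedding(index_mat)                        # klen x klen x 8
print(index_mat)
print(pos_emb.shape)
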
Example #21
    def forward(self,
                input,
                input_pos=None,
                input_lang=None,
                streaming=False,
                **kwargs):
        """
        Inputs Shapes:
            input: batch_size x src_len (to be transposed)
        Outputs Shapes:
            out: batch_size x src_len x d_model
            mask_src
        """
        """ Embedding: batch_size x src_len x d_model """
        if self.input_type == "text":
            mask_src = input.eq(onmt.constants.PAD).unsqueeze(
                1)  # batch_size x 1 x len_src for broadcasting

            # apply switchout
            # if self.switchout > 0 and self.training:
            #     vocab_size = self.word_lut.weight.size(0)
            #     input = switchout(input, vocab_size, self.switchout)

            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)
        else:
            if not self.cnn_downsampling:
                mask_src = input.narrow(2, 0, 1).squeeze(2).eq(
                    onmt.constants.PAD).unsqueeze(1)
                input = input.narrow(2, 1, input.size(2) - 1)
                emb = self.audio_trans(input.contiguous().view(
                    -1, input.size(2))).view(input.size(0), input.size(1), -1)
                emb = emb.type_as(input)
            else:
                long_mask = input.narrow(2, 0,
                                         1).squeeze(2).eq(onmt.constants.PAD)
                input = input.narrow(2, 1, input.size(2) - 1)

                # first resizing to fit the CNN format
                input = input.view(input.size(0), input.size(1), -1,
                                   self.channels)
                input = input.permute(0, 3, 1, 2)

                input = self.audio_trans(input)
                input = input.permute(0, 2, 1, 3).contiguous()
                input = input.view(input.size(0), input.size(1), -1)
                # print(input.size())
                input = self.linear_trans(input)

                mask_src = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1)
                # input is now batch_size x downsampled_len x d_model (after linear_trans)
                emb = input

        mask_src = mask_src.bool()
        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)
        """ Adding language embeddings """
        if self.use_language_embedding:
            assert self.language_embedding is not None

            if self.language_embedding_type in ['sum', 'all_sum']:
                lang_emb = self.language_embedding(input_lang)
                emb = emb + lang_emb.unsqueeze(1)

        time_encoding = self.positional_encoder.get_positional_embeddings(emb)

        # B x T x H -> T x B x H
        context = self.preprocess_layer(emb.transpose(0, 1))

        for i in range(self.max_layers):
            layer_vector = torch.LongTensor([i]).to(emb.device)
            layer_vector = self.layer_embedding(layer_vector).unsqueeze(
                0)  # 1 x 1 x model_size

            context = self.universal_layer(context, time_encoding,
                                           layer_vector, mask_src)

        # last layer norm
        context = self.postprocess_layer(context)

        output_dict = defaultdict(lambda: None, {
            'context': context,
            'src_mask': mask_src,
            'src': input
        })

        if streaming:
            # note: this branch relies on `streaming_state` and `input_length` being
            # provided by the caller; they are not defined in this forward
            streaming_state.prev_src_mem_size += sum(input_length.tolist())
            streaming_state.prune_source_memory(self.max_memory_size)
            # streaming_state.update_src_mems(hids, qlen)
            output_dict['streaming_state'] = streaming_state

        return output_dict
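
Example #21 is a universal (weight-shared) encoder: one layer is applied max_layers times, and each pass is conditioned on a learned layer-index embedding. A compact sketch of that loop, with a plain linear layer standing in for self.universal_layer and the layer embedding simply added to the input (the real layer may inject it differently):

import torch
import torch.nn as nn

model_size, max_layers = 8, 4
layer_embedding = nn.Embedding(max_layers, model_size)
shared_layer = nn.Linear(model_size, model_size)          # toy stand-in for the universal layer

x = torch.randn(5, 2, model_size)                         # T x B x H
for i in range(max_layers):
    layer_vec = layer_embedding(torch.tensor([i])).unsqueeze(0)   # 1 x 1 x H
    x = torch.relu(shared_layer(x + layer_vec))            # same weights on every iteration
print(x.shape)
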
Example #22
    def forward(self, input, input_lang=None, **kwargs):
        """
        Inputs Shapes:
            input: batch_size x len_src (to be transposed)

        Outputs Shapes:
            out: batch_size x len_src x d_model
            mask_src

        """
        """ Embedding: batch_size x len_src x d_model """
        if self.input_type == "text":
            mask_src = input.eq(onmt.constants.PAD).unsqueeze(
                1)  # batch_size x 1 x len_src for broadcasting

            # apply switchout
            # if self.switchout > 0 and self.training:
            #     vocab_size = self.word_lut.weight.size(0)
            #     input = switchout(input, vocab_size, self.switchout)

            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)
        else:
            if not self.cnn_downsampling:
                mask_src = input.narrow(2, 0, 1).squeeze(2).eq(
                    onmt.constants.PAD).unsqueeze(1)
                input = input.narrow(2, 1, input.size(2) - 1)
                emb = self.audio_trans(input.contiguous().view(
                    -1, input.size(2))).view(input.size(0), input.size(1), -1)
            else:
                long_mask = input.narrow(2, 0,
                                         1).squeeze(2).eq(onmt.constants.PAD)
                input = input.narrow(2, 1, input.size(2) - 1)

                # first resizing to fit the CNN format
                input = input.view(input.size(0), input.size(1), -1,
                                   self.channels)
                input = input.permute(0, 3, 1, 2)

                input = self.audio_trans(input)
                input = input.permute(0, 2, 1, 3).contiguous()
                input = input.view(input.size(0), input.size(1), -1)
                # print(input.size())
                input = self.linear_trans(input)

                mask_src = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1)
                # input is now batch_size x downsampled_len x d_model (after linear_trans)
                emb = input

        if torch_version >= 1.2:
            mask_src = mask_src.bool()
        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)
        """ Adding language embeddings """
        if self.use_language_embedding:
            assert self.language_embedding is not None

            if self.language_embedding_type in ['sum', 'all_sum']:
                lang_emb = self.language_embedding(input_lang)
                emb = emb + lang_emb.unsqueeze(1)

        # B x T x H -> T x B x H
        context = emb.transpose(0, 1)

        context = self.preprocess_layer(context)

        for i, layer in enumerate(self.layer_modules):

            context = layer(context,
                            mask_src)  # batch_size x len_src x d_model

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        context = self.postprocess_layer(context)

        output_dict = {'context': context, 'src_mask': mask_src}

        # return context, mask_src
        return output_dict
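
Several of the encoders above add a per-sentence language embedding with the 'sum' option; the only thing that changes between them is the unsqueeze dimension, which must match the layout of emb (dim 1 for batch-first B x T x H, dim 0 for time-first T x B x H). A small sketch of both broadcasts:

import torch

B, T, H = 2, 5, 8
lang_emb = torch.randn(B, H)                 # one embedding per sentence

emb_bth = torch.randn(B, T, H)               # batch-first layout
emb_bth = emb_bth + lang_emb.unsqueeze(1)    # B x 1 x H broadcasts over time

emb_tbh = torch.randn(T, B, H)               # time-first layout
emb_tbh = emb_tbh + lang_emb.unsqueeze(0)    # 1 x B x H broadcasts over time

print(emb_bth.shape, emb_tbh.shape)
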
Example #23
    def forward(self, batch, target_mask=None, streaming=False, **kwargs):

        tgt = batch.get('target_input')
        tgt_lang = batch.get('target_lang')

        if streaming:
            streaming_state = kwargs.get('streaming_state', None)
            mems = streaming_state.tgt_mems
        else:
            mems = None

        qlen = tgt.size(0)

        word_emb = embedded_dropout(
            self.tgt_embedding,
            tgt,
            dropout=self.word_dropout if self.training else 0)
        word_emb.mul_(self.model_size**0.5)

        if self.use_language_embedding:
            lang_emb = self.language_embeddings(tgt_lang)  # B x H

            if self.language_embedding_type in ['sum', 'all_sum']:
                word_emb = word_emb + lang_emb
            else:
                raise NotImplementedError

        mlen = mems[0].size(0) if mems is not None else 0

        # total length: memory + current input
        klen = mlen + qlen

        # all units having the same attention range
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (
                torch.triu(all_ones, 1 + mlen) +
                torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(word_emb.new_ones(qlen, klen),
                                       diagonal=1 + mlen).byte()[:, :, None]

        dec_attn_mask = dec_attn_mask.bool()

        pos = torch.arange(klen - 1,
                           -1,
                           -1.0,
                           device=word_emb.device,
                           dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos.clamp_(max=self.clamp_len)

        pos_emb = self.positional_encoder(pos)

        # Applying dropout
        output = self.preprocess_layer(word_emb)

        if streaming:
            hids = [output]

        pos_emb = self.preprocess_layer(pos_emb)

        # FORWARD PASS
        coverage = None
        for i, layer in enumerate(self.layer_modules):
            mems_i = None if mems is None else mems[i]
            output, coverage = layer(
                output, None, pos_emb, dec_attn_mask, None,
                mems=mems_i)  # context and context_mask are None
            if streaming:
                hids.append(output)

        # Final normalization
        output = self.postprocess_layer(output)

        output_dict = {
            'hidden': output,
            'coverage': coverage,
            'context': None,
            'src': None,
            'target_mask': target_mask
        }
        output_dict = defaultdict(lambda: None, output_dict)
        # final layer: computing log probabilities
        logprobs = self.generator[0](output_dict)
        output_dict['logprobs'] = logprobs

        if streaming:
            streaming_state.update_tgt_mems(hids, qlen)

            output_dict['streaming_state'] = streaming_state

        return output_dict
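
The decoder attention mask above follows the Transformer-XL recipe: every query may attend to the `mlen` cached memory steps plus all current positions up to and including its own, and the `same_length` branch additionally masks out the oldest keys so every query sees an equal-length window. A small sketch of the default (non-`same_length`) mask with toy sizes:

import torch

qlen, mlen = 4, 2          # current steps and cached memory steps (toy sizes)
klen = mlen + qlen

# True = masked. Row i (query) may attend to the mlen memory slots
# and to current positions <= i, i.e. everything left of the shifted diagonal.
dec_attn_mask = torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen).bool()
print(dec_attn_mask.int())
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]])
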
Example #24
    def forward_grow(self, input):
        """
        Inputs Shapes: 
            input: batch_size x len_src (wanna tranpose)
        
        Outputs Shapes:
            out: batch_size x len_src x d_model
            mask_src 
            
        """

        with torch.no_grad():
            """ Embedding: batch_size x len_src x d_model """
            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)
            """ Scale the emb by sqrt(d_model) """

            if self.time == 'positional_encoding':
                emb = emb * math.sqrt(self.model_size)
            """ Adding positional encoding """
            emb = self.time_transformer(emb)
            if isinstance(emb, tuple):
                emb = emb[0]
            emb = self.preprocess_layer(emb)

            mask_src = input.data.eq(onmt.constants.PAD).unsqueeze(
                1)  # batch_size x 1 x len_src for broadcasting

            pad_mask = torch.autograd.Variable(
                input.data.ne(onmt.constants.PAD))  # batch_size x len_src
            #~ pad_mask = None

            context = emb.contiguous()

            memory_bank = list()

            for i in range(self.pretrained_point):

                layer = self.layer_modules[i]

                context, norm_input = layer(
                    context, mask_src,
                    pad_mask)  # batch_size x len_src x d_model

                if i > 0:  # don't keep the norm input of the first layer (a.k.a embedding)
                    memory_bank.append(norm_input)

        for i in range(self.layers - self.pretrained_point):

            res_drop_rate = 0.0
            if i == 0:
                res_drop_rate = self.grow_dropout

            layer = self.layer_modules[self.pretrained_point + i]

            context, norm_input = layer(context,
                                        mask_src,
                                        pad_mask,
                                        residual_dropout=res_drop_rate
                                        )  # batch_size x len_src x d_model

            memory_bank.append(norm_input)

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        context = self.postprocess_layer(context)

        # make a huge memory bank on the encoder side
        memory_bank.append(context)

        memory_bank = torch.stack(memory_bank)

        return memory_bank, mask_src
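
`forward_grow` runs the first `pretrained_point` layers inside `torch.no_grad()` and only the newly added layers with gradients enabled, so the pretrained stack stays frozen while the grown layers train. A generic sketch of that pattern with made-up module names (not this repository's layers):

import torch
import torch.nn as nn

class GrowingStack(nn.Module):
    def __init__(self, d_model, n_pretrained, n_new):
        super().__init__()
        self.pretrained_point = n_pretrained
        self.layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(n_pretrained + n_new)])

    def forward_grow(self, x):
        # frozen part: no graph is built, so no gradients reach these layers
        with torch.no_grad():
            for layer in self.layers[:self.pretrained_point]:
                x = torch.relu(layer(x))
        # grown part: trained as usual
        for layer in self.layers[self.pretrained_point:]:
            x = torch.relu(layer(x))
        return x

model = GrowingStack(d_model=8, n_pretrained=2, n_new=2)
loss = model.forward_grow(torch.randn(3, 8)).sum()
loss.backward()
print([p.grad is not None for p in model.layers[0].parameters()])   # [False, False]
print([p.grad is not None for p in model.layers[-1].parameters()])  # [True, True]
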
Example #25
    def forward(self, input, input_lang=None, **kwargs):
        """
        Inputs Shapes:
            input: batch_size x len_src (to be transposed)

        Outputs Shapes:
            out: batch_size x len_src x d_model
            mask_src

        """
        """ Embedding: batch_size x len_src x d_model """
        if self.input_type == "text":
            mask_src = input.eq(onmt.constants.PAD).unsqueeze(
                1)  # batch_size x 1 x len_src for broadcasting

            # apply switchout
            # if self.switchout > 0 and self.training:
            #     vocab_size = self.word_lut.weight.size(0)
            #     input = switchout(input, vocab_size, self.switchout)

            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)
        else:
            if not self.cnn_downsampling:
                mask_src = input.narrow(2, 0, 1).squeeze(2).eq(
                    onmt.constants.PAD).unsqueeze(1)
                input = input.narrow(2, 1, input.size(2) - 1)
                emb = self.audio_trans(input.contiguous().view(
                    -1, input.size(2))).view(input.size(0), input.size(1), -1)
                emb = emb.type_as(input)
            else:
                long_mask = input.narrow(2, 0,
                                         1).squeeze(2).eq(onmt.constants.PAD)
                input = input.narrow(2, 1, input.size(2) - 1)

                # first resizing to fit the CNN format
                input = input.view(input.size(0), input.size(1), -1,
                                   self.channels)
                input = input.permute(0, 3, 1, 2)

                input = self.audio_trans(input)
                input = input.permute(0, 2, 1, 3).contiguous()
                input = input.view(input.size(0), input.size(1), -1)
                # print(input.size())
                input = self.linear_trans(input)

                mask_src = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1)
                # mask_src: batch_size x 1 x downsampled_len (one entry per CNN output frame)
                emb = input

        mask_src = mask_src.bool()
        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)
        """ Adding language embeddings """
        if self.use_language_embedding:
            assert self.language_embedding is not None

            if self.language_embedding_type in ['sum', 'all_sum']:
                lang_emb = self.language_embedding(input_lang)
                emb = emb + lang_emb.unsqueeze(1)

        # B x T x H -> T x B x H
        context = emb.transpose(0, 1)

        context = self.preprocess_layer(context)

        if self.reversible:
            # x_1 and x_2 are the same at first for reversible
            context = torch.cat([context, context], dim=-1)

            context = ReversibleEncoderFunction.apply(context,
                                                      self.layer_modules,
                                                      mask_src)
        else:
            for i, layer in enumerate(self.layer_modules):
                context = layer(context,
                                mask_src)  # len_src x batch_size x d_model

        context = self.postprocess_layer(context)

        output_dict = {'context': context, 'src_mask': mask_src}

        # return context, mask_src
        return output_dict
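
In the audio branch above, the CNN front end shrinks the time axis by roughly a factor of 4, which is why the padding mask is rebuilt by taking every 4th frame of the original frame-level mask (`long_mask[:, 0:T_out * 4:4]`). A rough sketch of why that stride lines up, assuming the downsampling is two stride-2 convolutions (the repository's actual `audio_trans` may be configured differently):

import torch
import torch.nn as nn

# two stride-2 convolutions over (time, freq) -> time length roughly divided by 4
cnn = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)

bsz, t_in, n_mel = 2, 37, 40
feats = torch.randn(bsz, 1, t_in, n_mel)
out = cnn(feats)                       # bsz x 16 x ~ceil(t_in / 4) x ~ceil(n_mel / 4)
t_out = out.size(2)

long_mask = torch.zeros(bsz, t_in, dtype=torch.bool)   # frame-level PAD mask
mask_src = long_mask[:, 0:t_out * 4:4].unsqueeze(1)    # one mask entry per output frame
print(t_out, mask_src.shape)           # 10 torch.Size([2, 1, 10])
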
Example #26
    def forward_grow(self, input, context, src):
        """
        Inputs Shapes: 
            input: (Variable) batch_size x len_tgt (wanna tranpose)
            context: (Variable) batch_size x len_src x d_model
            mask_src (Tensor) batch_size x len_src
        Outputs Shapes:
            out: batch_size x len_tgt x d_model
            coverage: batch_size x len_tgt x len_src
            
        """
        """ Embedding: batch_size x len_tgt x d_model """

        with torch.no_grad():

            emb = embedded_dropout(
                self.word_lut,
                input,
                dropout=self.word_dropout if self.training else 0)
            if self.time == 'positional_encoding':
                emb = emb * math.sqrt(self.model_size)
            """ Adding positional encoding """
            emb = self.time_transformer(emb)
            if isinstance(emb, tuple):
                emb = emb[0]
            emb = self.preprocess_layer(emb)

            mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1)

            pad_mask_src = torch.autograd.Variable(
                src.data.ne(onmt.constants.PAD))

            len_tgt = input.size(1)
            mask_tgt = input.data.eq(onmt.constants.PAD).unsqueeze(
                1) + self.mask[:len_tgt, :len_tgt]
            mask_tgt = torch.gt(mask_tgt, 0)

            output = emb.contiguous()

            pad_mask_tgt = torch.autograd.Variable(
                input.data.ne(onmt.constants.PAD))  # batch_size x len_tgt
            pad_mask_src = ~mask_src.squeeze(1)  # batch_size x len_src, True for non-PAD positions

            for i in range(self.pretrained_point):

                layer = self.layer_modules[i]

                output, coverage = layer(
                    output, context[i], mask_tgt, mask_src, pad_mask_tgt,
                    pad_mask_src)  # batch_size x len_src x d_model

        for i in range(self.layers - self.pretrained_point):

            res_drop_rate = 0.0
            if i == 0:
                res_drop_rate = self.grow_dropout

            layer = self.layer_modules[self.pretrained_point + i]
            output, coverage = layer(output,
                                     context[self.pretrained_point + i],
                                     mask_tgt,
                                     mask_src,
                                     pad_mask_tgt,
                                     pad_mask_src,
                                     residual_dropout=res_drop_rate
                                     )  # batch_size x len_src x d_model
        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        output = self.postprocess_layer(output)

        return output, coverage
Example #27
    def forward(self, dec_seq, enc_out, src, tgt_lang=None, hid=None, **kwargs):

        emb = embedded_dropout(self.word_lut, dec_seq, dropout=self.word_dropout if self.training else 0)
        emb = emb * math.sqrt(self.model_size)

        if self.use_language_embedding:
            # print("Using language embedding")
            lang_emb = self.language_embeddings(tgt_lang)  # B x H or 1 x H
            if self.language_embedding_type == 'sum':

                dec_emb = emb + lang_emb.unsqueeze(1)
            elif self.language_embedding_type == 'concat':
                # replace the bos embedding (first time step) with the language embedding;
                # emb is batch-first here (B x T x H), so index the time dimension
                bos_emb = lang_emb.expand_as(emb[:, 0])
                emb[:, 0] = bos_emb

                lang_emb = lang_emb.unsqueeze(1).expand_as(emb)
                concat_emb = torch.cat([emb, lang_emb], dim=-1)
                dec_emb = torch.relu(self.projector(concat_emb))
            else:
                raise NotImplementedError
        else:
            dec_emb = emb

        if enc_out is not None:
            if self.encoder_type == "audio":
                if not self.encoder_cnn_downsampling:
                    mask_src = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
                else:
                    long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
                    mask_src = long_mask[:, 0: enc_out.size(0) * 4:4].unsqueeze(1)
            else:

                mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1)
        else:
            mask_src = None

        # if dec_seq.size(0) > 1 and dec_seq.size(1) > 1:
        #     lengths = dec_seq.gt(onmt.constants.PAD).sum(-1)
        #     dec_in = pack_padded_sequence(dec_emb, lengths, batch_first=True, enforce_sorted=False)
        #     dec_out, hid = self.lstm(dec_in, hid)
        #     dec_out = pad_packed_sequence(dec_out, batch_first=True)[0]
        # else:
        if self.multilingual_factorized_weights:
            dec_out, hid = self.lstm(dec_emb, hid, indices=tgt_lang)
        else:
            dec_out, hid = self.lstm(dec_emb, hid)

        lt = dec_seq.size(1)
        attn_mask = mask_src.expand(-1, lt, -1) if not self.fast_xattention else mask_src.squeeze(1)
        # dec_out = self.postprocess_layer(dec_out)
        dec_out = self.preprocess_attn(dec_out)
        dec_out = dec_out.transpose(0, 1).contiguous()
        enc_out = enc_out.contiguous()

        if self.multilingual_factorized_weights:
            output, coverage = self.multihead_tgt(dec_out, enc_out, enc_out, tgt_lang, tgt_lang, attn_mask)
        else:
            output, coverage = self.multihead_tgt(dec_out, enc_out, enc_out, attn_mask)

        output = (output + dec_out)
        output = self.postprocess_layer(output)

        output_dict = defaultdict(lambda: None, {'hidden': output, 'coverage': coverage, 'context': enc_out})
        return output_dict
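
The `attn_mask` handling above only changes the mask's shape, not its content: a standard attention module broadcasts a `batch x len_tgt x len_src` mask, while the `fast_xattention` path apparently expects a flat `batch x len_src` key-padding mask. A toy illustration of the two shapes:

import torch

len_tgt, len_src, pad = 3, 5, 0
src = torch.tensor([[4, 7, 9, 0, 0],
                    [5, 2, 0, 0, 0]])

mask_src = src.eq(pad).unsqueeze(1)              # bsz x 1 x len_src

full_mask = mask_src.expand(-1, len_tgt, -1)     # bsz x len_tgt x len_src (one row per query)
flat_mask = mask_src.squeeze(1)                  # bsz x len_src (key padding only)

print(full_mask.shape, flat_mask.shape)
# torch.Size([2, 3, 5]) torch.Size([2, 5])
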
Example #28
    def forward(self, input, input_pos=None, input_lang=None, streaming=False, **kwargs):
        """
        Inputs Shapes:
            input: batch_size x src_len (to be transposed)
        Outputs Shapes:
            out: batch_size x src_len x d_model
            mask_src
        """

        """ Embedding: batch_size x src_len x d_model """
        bsz_first_input = input
        input = input.transpose(0, 1)

        dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

        mem_len = 0
        mask_src = input.eq(onmt.constants.PAD).unsqueeze(0)  # 1 x src_len x batch_size for broadcasting
        mems = None

        emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)

        """ Adding language embeddings """
        if self.use_language_embedding:
            assert self.language_embedding is not None
            # There is no "unsqueeze" here because the input is T x B x H and lang_emb is B x H
            if self.language_embedding_type in ['sum', 'all_sum']:
                lang_emb = self.language_embedding(input_lang)
                # print(lang_emb.size(), emb.size())
                emb = emb + lang_emb.unsqueeze(0)

        if self.unidirectional:
            qlen = input.size(0)
            klen = qlen + mem_len
            attn_mask_src = torch.triu(
                emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None]

            pad_mask = mask_src

            mask_src = pad_mask + attn_mask_src
            # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0)
            mask_src = mask_src.gt(0)

        if onmt.constants.torch_version >= 1.2:
            mask_src = mask_src.bool()

        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)

        """ Adding positional encoding """
        qlen = input.size(0)
        klen = qlen + mem_len

        # Asynchronous positions: 2K+1 positions instead of K+1
        if self.unidirectional:
            pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)
        else:
            pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device, dtype=emb.dtype)

        # pos_emb has size (2*klen - 1) x 1 x H in the bidirectional case, klen x 1 x H otherwise
        pos_emb = self.positional_encoder(pos, bsz=input.size(1))

        if self.learnable_position_encoding:
            raise NotImplementedError

        # emb is already T x B x H (the input was transposed at the top of forward)
        context = emb

        # Apply dropout to both context and pos_emb
        context = self.preprocess_layer(context)

        pos_emb = self.preprocess_layer(pos_emb)

        for i, layer in enumerate(self.layer_modules):
            # src_len x batch_size x d_model
            context = layer(context, pos_emb, mask_src, src_lang=input_lang)

        # final layer norm
        context = self.postprocess_layer(context)

        output_dict = defaultdict(lambda: None, {'context': context, 'src_mask': dec_attn_mask, 'src': input})

        return output_dict
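
The `pos` sequence above covers relative offsets `klen-1 ... -(klen-1)` in the bidirectional case (2*klen - 1 values) and `klen-1 ... 0` (klen values) in the unidirectional case. A minimal Transformer-XL-style sinusoidal table for such offsets; this is a generic sketch, not the repository's `positional_encoder`:

import torch

def relative_sinusoid(pos, d_model):
    """pos: 1-D tensor of (possibly negative) relative offsets -> len(pos) x 1 x d_model."""
    inv_freq = 1.0 / (10000 ** (torch.arange(0.0, d_model, 2.0) / d_model))
    sinusoid = pos.unsqueeze(1) * inv_freq.unsqueeze(0)        # len x d_model/2
    return torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1).unsqueeze(1)

klen, d_model = 6, 8
pos_bi = torch.arange(klen - 1, -klen, -1.0)   # 2*klen - 1 offsets, both directions
pos_uni = torch.arange(klen - 1, -1, -1.0)     # klen offsets, past only

print(relative_sinusoid(pos_bi, d_model).shape)   # torch.Size([11, 1, 8])
print(relative_sinusoid(pos_uni, d_model).shape)  # torch.Size([6, 1, 8])
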
Example #29
    def forward(self, input, context, src, input_pos=None, src_lang=None, tgt_lang=None,
                streaming=False, **kwargs):
        """
                Inputs Shapes:
                    input: (Variable) batch_size x len_tgt (wanna tranpose)
                    context: (Variable) batch_size x src_len x d_model
                    mask_src (Tensor) batch_size x src_len
                Outputs Shapes:
                    out: batch_size x len_tgt x d_model
                    coverage: batch_size x len_tgt x src_len

                """

        """ Embedding: batch_size x len_tgt x d_model """
        input = input.transpose(0, 1)  # T x B
        emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
        emb = emb * math.sqrt(self.model_size)

        mem_len = 0
        mems = None
        extra_context = None

        if self.use_language_embedding:
            lang_emb = self.language_embeddings(tgt_lang)  # B x H or 1 x H
            if self.language_embedding_type == 'sum':
                emb = emb + lang_emb
            elif self.language_embedding_type == 'concat':
                # replace the bos embedding with the language
                bos_emb = lang_emb.expand_as(emb[0])
                emb[0] = bos_emb

                lang_emb = lang_emb.unsqueeze(0).expand_as(emb)
                concat_emb = torch.cat([emb, lang_emb], dim=-1)
                emb = torch.relu(self.projector(concat_emb))
            else:
                raise NotImplementedError

        if context is not None:
            mask_src = src.eq(onmt.constants.PAD).unsqueeze(1)
        else:
            mask_src = None

        qlen = input.size(0)
        klen = qlen + mem_len

        # preparing self-attention mask. The input is either left or right aligned

        dec_attn_mask = torch.triu(
            emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None]
        # pad_mask = input.eq(onmt.constants.PAD).byte()  # L x B
        #
        # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0)
        # dec_attn_mask = dec_attn_mask.gt(0)
        dec_attn_mask = dec_attn_mask.bool()

        pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)

        pos_emb = self.positional_encoder(pos, bsz=input.size(1))
        output = self.preprocess_layer(emb.contiguous())
        pos_emb = self.preprocess_layer(pos_emb)

        for i, layer in enumerate(self.layer_modules):
            output, coverage, _ = layer(output, context, pos_emb, dec_attn_mask, mask_src,
                                        src_lang=src_lang, tgt_lang=tgt_lang)

        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        output = self.postprocess_layer(output)

        output_dict = {'hidden': output, 'coverage': coverage, 'context': context}
        output_dict = defaultdict(lambda: None, output_dict)

        return output_dict
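
Several decoders above share the 'concat' language-embedding variant: the BOS embedding is overwritten with the language embedding, the language vector is then concatenated to every position, and the result is projected back to d_model. A standalone sketch of that scheme for time-first (T x B x H) tensors; the embedding sizes and the `projector` layer here are stand-ins, not the repository's modules:

import torch
import torch.nn as nn

len_tgt, bsz, d_model, n_langs = 4, 2, 8, 3

word_embedding = nn.Embedding(100, d_model)
language_embeddings = nn.Embedding(n_langs, d_model)
projector = nn.Linear(2 * d_model, d_model)       # stand-in for self.projector

tokens = torch.randint(0, 100, (len_tgt, bsz))    # T x B
tgt_lang = torch.tensor([1, 2])                   # one language id per sentence

emb = word_embedding(tokens)                      # T x B x H
lang_emb = language_embeddings(tgt_lang)          # B x H

# 1) replace the BOS (first time step) embedding with the language embedding
emb[0] = lang_emb.expand_as(emb[0])

# 2) concatenate the language vector at every position and project back to d_model
lang_all = lang_emb.unsqueeze(0).expand_as(emb)   # T x B x H
emb = torch.relu(projector(torch.cat([emb, lang_all], dim=-1)))
print(emb.shape)                                  # torch.Size([4, 2, 8])
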
Example #30
    def forward(self, batch, target_mask=None, **kwargs):

        src = batch.get('source').transpose(
            0, 1)  # src_len x batch_size -> bsz x src_len
        tgt = batch.get('target_input').transpose(
            0, 1)  # len_tgt x batch_size -> bsz x tgt_len
        src_pos = batch.get('source_pos')
        tgt_pos = batch.get('target_pos')
        src_lang = batch.get('source_lang')
        tgt_lang = batch.get('target_lang')

        tgt_len = tgt.size(1)
        src_len = src.size(1)
        bsz = tgt.size(0)

        # Embedding stage (and scale the embedding)
        src_emb = embedded_dropout(self.src_embedding, src, dropout=self.word_dropout if self.training else 0) \
                  * math.sqrt(self.model_size)
        tgt_emb = embedded_dropout(self.tgt_embedding, tgt, dropout=self.word_dropout if self.training else 0) \
                  * math.sqrt(self.model_size)

        # Add position encoding
        src_emb = self.time_transformer(src_emb)
        tgt_emb = self.time_transformer(tgt_emb)

        if self.use_language_embedding:
            if self.language_embedding_type in ["sum", "all_sum"]:
                src_lang_emb = self.language_embeddings(src_lang)
                src_emb += src_lang_emb.unsqueeze(1)
                tgt_lang_emb = self.language_embeddings(tgt_lang)
                tgt_emb += tgt_lang_emb.unsqueeze(1)

        # concatenate embedding
        emb = torch.cat([src_emb, tgt_emb], dim=1)  # batch_size x L x H

        # prepare self-attention mask
        # For the source: we have two different parts
        # [1 x src_len x batch_size]
        # mask_src_src = src.eq(onmt.constants.PAD).unsqueeze(0).byte()
        # src_pad_mask = mask_src_src
        # # Attention from src to target: everything is padded
        # mask_src_tgt = mask_src_src.new_ones(1, 1, 1).expand(src_len, tgt_len, bsz)
        # # [src_len x L x batch_size]
        # mask_src = torch.cat([mask_src_src.expand(src_len, src_len, bsz), mask_src_tgt], dim=1)
        # mask_src = mask_src.bool()
        # mask_src_src = src.eq(onmt.constants.PAD).unsqueeze(1).byte()  # B x 1 x src_len
        # mask_src_tgt = mask_src_src.new_ones(bsz, src_len, tgt_len)  # bsz x src_len x tgt_len
        #
        # mask_src = torch.cat([mask_src_src.expand(bsz, src_len, src_len), mask_src_tgt], dim=-1)
        #
        # # For the target:
        # mask_tgt_tgt = tgt.eq(onmt.constants.PAD).byte().unsqueeze(1) + self.mask[:tgt_len, :tgt_len]
        # mask_tgt_tgt = torch.gt(mask_tgt_tgt, 0).byte()  # bsz x tgt_len x tgt_len
        #
        # mask_tgt_src = mask_tgt_tgt.new_zeros(bsz, tgt_len, src_len) + src.eq(onmt.constants.PAD).unsqueeze(1).byte()
        # mask_tgt = torch.cat([mask_tgt_src, mask_tgt_tgt], dim=-1)  # bsz x tgt_len x T
        #
        # attn_mask = torch.cat([mask_src, mask_tgt], dim=1).bool()     # L x L x batch_size

        # lets try to use language modeling style
        # input_seq = torch.cat([src, tgt], dim=-1)
        # seq_len = input_seq.size(1)
        #
        # attn_mask = self.mask[:seq_len, :seq_len] + input_seq.eq(onmt.constants.PAD).byte().unsqueeze(1)
        # attn_mask = torch.gt(attn_mask, 0).bool()
        attn_mask = self.gen_mask(src, tgt)

        output = emb

        # Applying dropout and transpose to T x B x H
        output = self.preprocess_layer(output).transpose(0, 1)

        # FORWARD PASS
        coverage = None
        for i, layer in enumerate(self.layer_modules):
            output, coverage = layer(output, None, attn_mask,
                                     None)  # context and context_mask are None

        # Final normalization
        output = self.postprocess_layer(output)

        # extract the "source" and "target" parts of the output
        context = output[:src_len, :, :]
        output = output[-tgt_len:, :, :]
        output_dict = {
            'hidden': output,
            'coverage': coverage,
            'context': context,
            'src': src,
            'target_mask': target_mask
        }

        # final layer: computing log probabilities
        logprobs = self.generator[0](output_dict)
        output_dict['logprobs'] = logprobs

        return output_dict
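
`self.gen_mask(src, tgt)` is not shown here, but the commented-out construction above spells out the intent: source positions may attend to all non-PAD source positions and never to the target, while target positions may attend to the non-PAD source plus a causal window over the target. A hedged reconstruction of such a mask builder (batch-first shapes; my own sketch, not the repository's implementation):

import torch

PAD = 0

def gen_joint_mask(src, tgt):
    """src: bsz x src_len, tgt: bsz x tgt_len ->
    bool mask of shape bsz x (src_len + tgt_len) x (src_len + tgt_len), True = blocked."""
    bsz, src_len = src.size()
    tgt_len = tgt.size(1)

    src_pad = src.eq(PAD).unsqueeze(1)                       # bsz x 1 x src_len

    # source rows: may look at non-PAD source, never at the target
    mask_src_rows = torch.cat(
        [src_pad.expand(bsz, src_len, src_len),
         torch.ones(bsz, src_len, tgt_len, dtype=torch.bool)], dim=-1)

    # target rows: may look at non-PAD source, plus a causal window over the target
    causal = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)
    tgt_pad = tgt.eq(PAD).unsqueeze(1)                       # bsz x 1 x tgt_len
    mask_tgt_rows = torch.cat(
        [src_pad.expand(bsz, tgt_len, src_len),
         causal.unsqueeze(0) | tgt_pad], dim=-1)

    return torch.cat([mask_src_rows, mask_tgt_rows], dim=1)  # bsz x L x L

src = torch.tensor([[5, 6, 0]])
tgt = torch.tensor([[2, 7, 8, 0]])
print(gen_joint_mask(src, tgt).int()[0])
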