def forward(
        self,
        src_tokens,
        src_lengths,
        z_src_tokens,
        z_src_lengths,
        prev_output_tokens,
        cls_input: Optional[Tensor] = None,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Run the forward pass for an encoder-decoder model.

        Copied from the base class, but without ``**kwargs``,
        which are not supported by TorchScript.
        """
        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            cls_input=cls_input,
            return_all_hiddens=return_all_hiddens,
        )
        x = self.f1(encoder_out.encoder_out, encoder_out.encoder_padding_mask)
        encoder_out = EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_out.encoder_padding_mask,  # B x T
            encoder_embedding=encoder_out.encoder_embedding,  # B x T x C
            encoder_states=encoder_out.encoder_states,  # List[T x B x C]
        )

        z_encoder_out = self.encoder(
            z_src_tokens,
            src_lengths=z_src_lengths,
            cls_input=cls_input,
            return_all_hiddens=return_all_hiddens,
        )
        x = self.f2(z_encoder_out.encoder_out,
                    z_encoder_out.encoder_padding_mask)
        z_encoder_out = EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=z_encoder_out.encoder_padding_mask,  # B x T
            encoder_embedding=z_encoder_out.encoder_embedding,  # B x T x C
            encoder_states=z_encoder_out.encoder_states,  # List[T x B x C]
        )

        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
            z_encoder_out=z_encoder_out,
            z_src_lengths=z_src_lengths,
        )
        return decoder_out
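For reference, the EncoderOut constructed throughout these examples is the NamedTuple returned by fairseq encoders. Its exact field set varies by fairseq version (older releases stop at encoder_states; later ones add src_tokens and src_lengths, and some examples below extend it further), so the following is only a minimal compatible sketch, not the canonical definition.

from typing import List, NamedTuple, Optional

from torch import Tensor


class EncoderOut(NamedTuple):
    # Sketch of a fairseq-style EncoderOut; the field set differs across versions.
    encoder_out: Tensor                      # T x B x C
    encoder_padding_mask: Optional[Tensor]   # B x T, True at padded positions
    encoder_embedding: Optional[Tensor]      # B x T x C
    encoder_states: Optional[List[Tensor]]   # List[T x B x C]
    src_tokens: Optional[Tensor] = None      # B x T
    src_lengths: Optional[Tensor] = None     # B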
Example #2
    def forward(self, src_tokens, src_lengths):
        x, input_lengths = self.subsample(src_tokens, src_lengths)
        x = self.embed_scale * x

        encoder_padding_mask = lengths_to_padding_mask(input_lengths)
        positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
        x += positions
        x = self.dropout_module(x)

        for layer in self.transformer_layers:
            x = layer(x, encoder_padding_mask)

        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,
            encoder_padding_mask=encoder_padding_mask,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
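The lengths_to_padding_mask helper used above comes from fairseq's data utilities; as a minimal sketch of its behavior (it returns a B x T bool mask with True at padded positions):

import torch

def lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    # True at padded positions, e.g. lengths=[3, 1] -> [[F, F, F], [F, T, T]]
    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # 1 x T
    return positions >= lengths.unsqueeze(1)  # B x T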
Example #3
 def forward(self, src_tokens, src_lengths=None, **kwargs):
     return EncoderOut(
         encoder_out=src_tokens,
         encoder_padding_mask=None,
         encoder_embedding=None,
         encoder_states=None,
     )
Example #4
    def generate(self, models, sample, **unused):
        """Generate a batch of inferences.
        EncoderOut(
            encoder_out=encoder_out['encoder_out'],  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=encoder_out['encoder_padding_mask'],  # B x T
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
        """
        encoder_output = models[0].get_encoder_output(sample['net_input'])
        encoder_out = {
            "encoder_out":
            encoder_output.encoder_out.transpose(0, 1),  # B x T x C
            "padding_mask": encoder_output.encoder_padding_mask
        }
        alphas, _ = models[0].assigner(encoder_out)
        # _alphas, num_output = self.resize(alphas, kwargs['target_lengths'], at_least_one=True)
        cif_outputs = models[0].cif(encoder_out, alphas)
        src_lengths = torch.round(alphas.sum(-1)).int()
        self.step_forward_fn = models[0].decode
        encoder_output = EncoderOut(
            encoder_out=cif_outputs.transpose(0, 1),  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=~utils.sequence_mask(
                src_lengths, dtype=torch.bool),  # B x T
            encoder_states=None,
            src_tokens=None,
            src_lengths=src_lengths,
        )

        return self.decode(encoder_output)
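A small worked sketch (not from the source) of how the CIF quantities above fit together: alphas are per-frame firing weights, their sum rounds to a predicted output length, and inverting a sequence mask over those lengths yields the B x T padding mask passed to EncoderOut.

import torch

alphas = torch.tensor([[0.4, 0.7, 0.9, 0.0],
                       [0.5, 0.5, 0.0, 0.0]])      # B x T_in firing weights
src_lengths = torch.round(alphas.sum(-1)).int()    # tensor([2, 1], dtype=torch.int32)

max_len = int(src_lengths.max())
# equivalent of ~utils.sequence_mask(src_lengths, dtype=torch.bool)
encoder_padding_mask = (
    torch.arange(max_len).unsqueeze(0) >= src_lengths.unsqueeze(1)
)  # tensor([[False, False], [False, True]])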
Example #5
    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        if self.layer_wise_attention:
            return_all_hiddens = True

        x, encoder_embedding = self.forward_embedding(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        encoder_states = [] if return_all_hiddens else None

        position_bias = None
        if self.rel_pos:
            seq_len, bsz, _ = x.size()
            position_bias = self.compute_bias(seq_len, seq_len)  # (1, n_heads, qlen, klen)
            position_bias = position_bias.repeat(bsz, 1, 1, 1)
            position_bias = position_bias.view(bsz * self.num_heads, seq_len, seq_len)

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask, position_bias=position_bias)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
        )
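compute_bias is not shown in this snippet; purely as an illustration (hypothetical module and parameter names, not the model's actual implementation), a learned relative-position bias that returns the (1, n_heads, qlen, klen) shape consumed above could look like this:

import torch
import torch.nn as nn


class RelativePositionBias(nn.Module):
    # Hypothetical sketch: one learned scalar per (clipped relative distance, head).
    def __init__(self, n_heads: int, max_distance: int = 128):
        super().__init__()
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(2 * max_distance + 1, n_heads)

    def compute_bias(self, qlen: int, klen: int) -> torch.Tensor:
        context = torch.arange(qlen)[:, None]
        memory = torch.arange(klen)[None, :]
        relative = (memory - context).clamp(-self.max_distance, self.max_distance)
        bias = self.relative_attention_bias(relative + self.max_distance)  # qlen x klen x n_heads
        return bias.permute(2, 0, 1).unsqueeze(0)  # 1 x n_heads x qlen x klen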
Example #6
 def reorder_encoder_out(self, encoder_out, new_order):
     return EncoderOut(
         encoder_out=encoder_out.encoder_out.index_select(0, new_order),
         encoder_padding_mask=None,
         encoder_embedding=None,
         encoder_states=None,
     )
Example #7
    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False, return_all_attn: bool = False):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            return_all_attn (bool, optional): also return all of the
                intermediate layers' attention weights (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **encoder_attn** (List[Tensor]): all intermediate
                  layers' attention weights of shape `(num_heads, batch, src_len, src_len)`.
                  Only populated if *return_all_attn* is True.
        """
        x, encoder_embedding = self.forward_embedding(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        encoder_states = [] if return_all_hiddens else None
        encoder_attn = [] if return_all_attn else None

        # encoder layers
        for layer in self.layers:
            x, attn = layer(x, encoder_padding_mask, need_head_weights=return_all_attn)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            if return_all_attn and attn is not None:
                assert encoder_attn is not None
                encoder_attn.append(attn)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            encoder_attn=encoder_attn,  # List[N x B x T x T]
            src_tokens=None,
            src_lengths=None,
        )
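Since this variant also returns per-layer attention weights, it cannot use the stock EncoderOut; the constructor call above implies a locally extended NamedTuple along these lines (hypothetical name; fields taken from the call):

from typing import List, NamedTuple, Optional

from torch import Tensor


class EncoderOutWithAttn(NamedTuple):
    encoder_out: Tensor                      # T x B x C
    encoder_padding_mask: Optional[Tensor]   # B x T
    encoder_embedding: Optional[Tensor]      # B x T x C
    encoder_states: Optional[List[Tensor]]   # List[T x B x C]
    encoder_attn: Optional[List[Tensor]]     # List[N x B x T x T]
    src_tokens: Optional[Tensor]             # B x T
    src_lengths: Optional[Tensor]            # B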
Example #8
    def reorder_encoder_out(self, encoder_out, new_order):
        """
        if self.beam_size < 0:
            self.beam_size = int(new_order.shape[0] / self.batch_size)
        else:
            new_order = new_order // self.beam_size
        new_order = new_order[:: self.beam_size]
        new_encoder_out = encoder_out.encoder_out.index_select(1, new_order)
        new_encoder_padding_mask = encoder_out.encoder_padding_mask.index_select(
            0, new_order
        )
        """
        new_encoder_out = encoder_out.encoder_out.index_select(1, new_order)
        new_encoder_padding_mask = encoder_out.encoder_padding_mask.index_select(
            0, new_order
        )

        return EncoderOut(
            encoder_out=new_encoder_out,  # T x B x C
            encoder_padding_mask=new_encoder_padding_mask,  # B x T
            encoder_embedding=None,  # B x T x C
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
Example #9
    def forward(self,
                src_tokens,
                cluster_ids,
                src_lengths,
                return_all_hiddens: bool = False):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        # print(src_tokens)
        x, encoder_embedding = self.forward_embedding(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            if isinstance(layer, TransformerClusterEncoderLayer):
                x = layer(x, encoder_padding_mask, int(cluster_ids[0]))
            else:
                x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
        )
Example #10
 def forward(self, src_videos, src_lengths=None, **kwargs):
     x = self.cnn(src_videos.transpose(1, 2).contiguous())  # B x C x T
     x = x.transpose(1, 2).contiguous().transpose(0, 1)  # T x B x C
     return EncoderOut(
         encoder_out=x,  # T x B x C
         encoder_padding_mask=None,  # B x T
         encoder_embedding=None,  # B x T x C
         encoder_states=None,  # List[T x B x C]
     )
Example #11
    def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """

        """
        Since encoder_padding_mask and encoder_embedding are both of type
        Optional[Tensor] in EncoderOut, they need to be copied as local
        variables for Torchscript Optional refinement
        """
        encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask
        encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding

        new_encoder_out = (
            encoder_out.encoder_out
            if encoder_out.encoder_out is None
            else encoder_out.encoder_out.index_select(1, new_order)
        )
        new_encoder_padding_mask = (
            encoder_padding_mask
            if encoder_padding_mask is None
            else encoder_padding_mask.index_select(0, new_order)
        )
        new_encoder_embedding = (
            encoder_embedding
            if encoder_embedding is None
            else encoder_embedding.index_select(0, new_order)
        )
        src_tokens = encoder_out.src_tokens
        if src_tokens is not None:
            src_tokens = src_tokens.index_select(0, new_order)

        src_lengths = encoder_out.src_lengths
        if src_lengths is not None:
            src_lengths = src_lengths.index_select(0, new_order)

        encoder_states = encoder_out.encoder_states
        if encoder_states is not None:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return EncoderOut(
            encoder_out=new_encoder_out,  # T x B x C
            encoder_padding_mask=new_encoder_padding_mask,  # B x T
            encoder_embedding=new_encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=src_tokens,  # B x T
            src_lengths=src_lengths,  # B x 1
        )
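As a usage sketch (not part of the source): during beam search, fairseq-style generators tile each source sentence beam_size times and call reorder_encoder_out with an index over the flattened batch, so the T x B x C tensor is gathered along dim 1 and the B x * tensors along dim 0.

import torch

T, bsz, C, beam_size = 7, 2, 16, 3
encoder_out = torch.randn(T, bsz, C)

# initial tiling: [0, 0, 0, 1, 1, 1] selects each sentence beam_size times
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)

tiled = encoder_out.index_select(1, new_order)   # T x (bsz * beam_size) x C
assert tiled.shape == (T, bsz * beam_size, C)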
Example #12
    def forward(self, src_tokens, src_lengths: Tensor, **unused):
        if self.left_pad:
            # nn.utils.rnn.pack_padded_sequence requires right-padding;
            # convert left-padding to right-padding
            src_tokens = speech_utils.convert_padding_direction(
                src_tokens,
                src_lengths,
                left_to_right=True,
            )

        if self.conv_layers_before is not None:
            x, src_lengths, padding_mask = self.conv_layers_before(
                src_tokens, src_lengths)
        else:
            x, padding_mask = src_tokens, \
                ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1))

        bsz, seqlen = x.size(0), x.size(1)

        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        state_size = (2 if self.bidirectional else 1, bsz, self.hidden_size)
        h0, c0 = x.new_zeros(*state_size), x.new_zeros(*state_size)

        for i in range(len(self.lstm)):
            if self.residual and i > 0:  # residual connection starts from the 2nd layer
                prev_x = x
            # pack embedded source tokens into a PackedSequence
            packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data)

            # apply LSTM
            packed_outs, (_, _) = self.lstm[i](packed_x, (h0, c0))

            # unpack outputs and apply dropout
            x, _ = nn.utils.rnn.pad_packed_sequence(
                packed_outs, padding_value=self.padding_value * 1.0)
            if i < len(self.lstm) - 1:  # no dropout after the last layer
                x = F.dropout(x, p=self.dropout_out, training=self.training)
            x = x + prev_x if self.residual and i > 0 else x
        assert list(x.size()) == [seqlen, bsz, self.output_units]

        encoder_padding_mask = padding_mask.t()

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask
            if encoder_padding_mask.any() else None,  # T x B
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=src_lengths,  # B
        )
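A note on the pack_padded_sequence call above: it expects right-padded T x B x C input (batch_first=False), lengths sorted in decreasing order unless enforce_sorted=False, and, in recent PyTorch releases, lengths on the CPU. A minimal self-contained round trip:

import torch
import torch.nn as nn

T, B, C, H = 5, 2, 4, 8
x = torch.randn(T, B, C)
lengths = torch.tensor([5, 3])  # sorted descending

packed = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), enforce_sorted=True)
packed_out, _ = nn.LSTM(C, H)(packed)
out, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_out, padding_value=0.0)
assert out.shape == (T, B, H) and out_lengths.tolist() == [5, 3]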
Example #13
 def get_encoder_output(self, net_input):
     encoder_out = self.encoder(tbc=True, **net_input)
     return EncoderOut(
         encoder_out=encoder_out['encoder_out'],  # T x B x C
         encoder_embedding=None,
         encoder_padding_mask=encoder_out['encoder_padding_mask'],  # B x T
         encoder_states=None,
         src_tokens=None,
         src_lengths=None,
     )
Example #14
 def forward(self, src_tokens, src_lengths, **kwargs):
     d = super().forward(c, src_lengths, **kwargs)  # 'c' is the feature tensor computed upstream (not shown in this snippet)
     epm = d.get('encoder_padding_mask', None)
     epm = epm.t() if epm is not None else None
     return EncoderOut(
         encoder_out=d['encoder_out'],  # T x B x C
         encoder_padding_mask=epm,  # B x T
         encoder_embedding=None,  # B x T x C
         encoder_states=None,  # List[T x B x C]
     )
Example #15
 def forward(self, src_tokens, src_lengths=None, **kwargs):
     assert "fancy_other_input" in kwargs
     assert kwargs["fancy_other_input"] is not None
     return EncoderOut(
         encoder_out=src_tokens,
         encoder_padding_mask=None,
         encoder_embedding=None,
         encoder_states=None,
         src_tokens=None,
         src_lengths=None,
     )
Example #16
 def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
     encoder_padding_mask = encoder_out.encoder_padding_mask.index_select(1, new_order) \
         if encoder_out.encoder_padding_mask is not None else None
     return EncoderOut(
         encoder_out=encoder_out.encoder_out.index_select(1, new_order),
         encoder_padding_mask=encoder_padding_mask,
         encoder_embedding=None,
         encoder_states=None,
         src_tokens=None,
         src_lengths=encoder_out.src_lengths.index_select(0, new_order),
     )
Example #17
    def forward(
        self,
        src_tokens,
        src_lengths,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        out = super().forward(src_tokens,
                              src_lengths,
                              return_all_hiddens=return_all_hiddens)
        x, x_lengths = out.encoder_out, out.src_lengths

        # determine which output frame to select for loss evaluation/test, assuming
        # all examples in a batch are of the same length for chunk-wise training/test
        if (self.out_chunk_end is not None
                and (self.training or not self.training_stage)):
            x = x[self.out_chunk_begin:
                  self.out_chunk_end]  # T x B x C -> W x B x C
            x_lengths = x_lengths.fill_(x.size(0))

        if self.fc_out is not None:
            x = self.fc_out(x)  # T x B x C -> T x B x V

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=out.encoder_padding_mask.transpose(
                0, 1),  # T x B
            encoder_embedding=out.encoder_embedding,  # None
            encoder_states=out.encoder_states,  # List[T x B x C]
            src_tokens=out.src_tokens,  # None
            src_lengths=x_lengths,  # B
        )
Example #18
    def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        new_encoder_out: Dict[str, Tensor] = {}

        new_encoder_out["encoder_out"] = (
            encoder_out.encoder_out if encoder_out.encoder_out is None else
            encoder_out.encoder_out.index_select(1, new_order))
        new_encoder_out["encoder_padding_mask"] = (
            encoder_out.encoder_padding_mask
            if encoder_out.encoder_padding_mask is None else
            encoder_out.encoder_padding_mask.index_select(0, new_order))
        new_encoder_out["encoder_embedding"] = (
            encoder_out.encoder_embedding
            if encoder_out.encoder_embedding is None else
            encoder_out.encoder_embedding.index_select(0, new_order))
        src_tokens = encoder_out.src_tokens
        if src_tokens is not None:
            src_tokens = src_tokens.index_select(0, new_order)

        src_lengths = encoder_out.src_lengths
        if src_lengths is not None:
            src_lengths = src_lengths.index_select(0, new_order)

        encoder_states = encoder_out.encoder_states
        if encoder_states is not None:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        new_encoder_out["bottom_features"] = (
            encoder_out.bottom_features if encoder_out.bottom_features is None
            else encoder_out.bottom_features.index_select(0, new_order))

        return EncoderOut(
            encoder_out=new_encoder_out["encoder_out"],  # T x B x C
            encoder_padding_mask=new_encoder_out[
                "encoder_padding_mask"],  # B x T
            encoder_embedding=new_encoder_out[
                "encoder_embedding"],  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=src_tokens,  # B x T
            src_lengths=src_lengths,  # B x 1
            bottom_features=new_encoder_out["bottom_features"],  # B x T'
        )
Example #19
    def forward(self, src_tokens, src_lengths=None, **kwargs):
        b_sz, t_sz = src_tokens.shape
        padding_needed = t_sz % 2
        x = src_tokens
        if padding_needed > 0:
            padding_needed = 2 - padding_needed
            x = F.pad(x, (0, padding_needed))

        return EncoderOut(
            encoder_out=x.view(b_sz, -1, 2),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
        )
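A quick shape check of the padding logic above (illustration only): F.pad(x, (0, n)) pads the last dimension on the right, after which the view pairs adjacent tokens.

import torch
import torch.nn.functional as F

src_tokens = torch.tensor([[4, 5, 6]])        # B=1, T=3 (odd length)
padding_needed = src_tokens.shape[1] % 2      # 1
x = src_tokens
if padding_needed > 0:
    x = F.pad(x, (0, 2 - padding_needed))     # pad last dim on the right -> [[4, 5, 6, 0]]
print(x.view(x.shape[0], -1, 2))              # tensor([[[4, 5], [6, 0]]])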
Example #20
 def combine_encoder_out(self, outs):
     encoder_out = torch.cat([out[0] for out in outs], 0)
     encoder_padding_mask = torch.cat([out[1] for out in outs], 1)
     encoder_embedding = torch.cat([out[2] for out in outs], 1)
     encoder_states = None
     if all(out[3] is not None for out in outs):
         encoder_states = torch.cat([out[3] for out in outs], 0)
     return EncoderOut(
         encoder_out=encoder_out,  # T x B x C
         encoder_padding_mask=encoder_padding_mask,  # B x T
         encoder_embedding=encoder_embedding,  # B x T x C
         encoder_states=encoder_states,  # List[T x B x C]
         src_tokens=None,
         src_lengths=None)
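A shape sanity check for the concatenation above (illustration only): the T x B x C outputs are joined along the time axis (dim 0), while the B x T masks and B x T x C embeddings are joined along their time axis (dim 1), so the batch dimension must match across chunks.

import torch

T1, T2, B, C = 4, 6, 2, 8
outs = [(torch.randn(T1, B, C), torch.zeros(B, T1, dtype=torch.bool)),
        (torch.randn(T2, B, C), torch.zeros(B, T2, dtype=torch.bool))]

encoder_out = torch.cat([o for o, _ in outs], 0)           # (T1 + T2) x B x C
encoder_padding_mask = torch.cat([m for _, m in outs], 1)  # B x (T1 + T2)
assert encoder_out.shape == (T1 + T2, B, C)
assert encoder_padding_mask.shape == (B, T1 + T2)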
Example #21
    def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
        if encoder_out.encoder_padding_mask is not None:
            epm = encoder_out.encoder_padding_mask.index_select(1, new_order)
        else:
            epm = encoder_out.encoder_padding_mask

        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(1, new_order),  # T x B x C
            encoder_padding_mask=epm,  # B x T
            encoder_embedding=None,  # B x T x C
            encoder_states=None,  # List[T x B x C]
        )
Example #22
    def forward(self, token_id, mask_label=None, decode_label=None, label=None):
        # batch_size, can_num, can_length = candidate_id.shape
        # batch_size, _, his_length = his_id.shape
        # print('shapes:', token_id.shape, mask_label.shape, decode_label.shape)
        if label is not None:
            return self.predict(token_id, label)

        token_features, _ = self.encoder(token_id)  # bsz x length x dim
        token_features = token_features[-1].transpose(0, 1)  # [:, 0, :]
        loss_mask, sample_size_mask = self.predict_mask(token_features, mask_label)

        h = token_features[:, 0:, ]
        h = EncoderOut(
            encoder_out=h,  # T x B x C
            encoder_padding_mask=None,  # B x T
            encoder_embedding=None,  # B x T x C
            encoder_states=None,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
        )

        loss_decode, sample_size_decode = self.predict_decode(h, decode_label)
        # loss = F.nll_loss(
        #     F.log_softmax(res.view(-1, res.size(-1)), dim=-1, dtype=torch.float32),
        #     label.view(-1),
        #     reduction='sum',
        #     # ignore_index=self.padding_idx,
        # )
        # loss = 0.5 * loss_decode + 0.5 * loss_mask
        # loss, sample_size = loss_mask, sample_size_mask
        # loss, sample_size = loss_decode, sample_size_decode
        # return loss, sample_size  # , torch.tensor(sample_size).cuda()

        return loss_mask, sample_size_mask, loss_decode, sample_size_decode
Example #23
 def score(self, src_tokens, tgt_tokens):
     src_tokens = src_tokens[:, 1:]
     assert src_tokens.shape[0] == 1
     unique_tgt_tokens = tgt_tokens.unique(dim=0)
     x = src_tokens[0].cpu().numpy()
     for i in range(x.shape[0]):
         x[i] = self.src_vmap[x[i]]
     y = unique_tgt_tokens.cpu().numpy()
     for r in range(y.shape[0]):
         pad = False
         for c in range(y.shape[1]):
             if pad:
                 y[r][c] = 1  # fairseq's default pad index
             else:
                 if y[r][c] == 2:  # fairseq's default eos index
                     pad = True
                 y[r][c] = self.tgt_vmap[y[r][c]]
     B = unique_tgt_tokens.shape[0]
     with torch.no_grad():
         x_tensor = torch.tensor(x)[None, :].cuda()
         y_tensor = torch.tensor(y).cuda()
         x_lens = torch.tensor([x_tensor.shape[1]]).cuda()
         y_lens = torch.ne(y_tensor, 1).sum(1) - 1
         # Transformer forward >>>
         encoder_out = self.transformer.encoder(x_tensor,
                                                src_lengths=x_lens,
                                                return_all_hiddens=False)
         encoder_out = EncoderOut(
             encoder_out.encoder_out.repeat(1, B, 1),
             encoder_out.encoder_padding_mask.repeat(B, 1),
             encoder_out.encoder_embedding.repeat(B, 1, 1), None, None,
             None)
         decoder_out = self.transformer.decoder(
             y_tensor[:, :-1],
             encoder_out=encoder_out,
             src_lengths=x_lens.repeat(B),
             return_all_hiddens=False,
         )
         logits = decoder_out[0]
         # <<<
         logp = torch.log_softmax(logits, 2)
         _, L, V = logp.shape
         token_logp = logp.view(B * L,
                                V)[torch.arange(B * L),
                                   y_tensor[:, 1:].flatten()].view(B, L)
         y_mask = torch.arange(L).unsqueeze(0).repeat(
             B, 1).cuda() < y_lens[:, None]
         scores = (token_logp * y_mask).sum(1) / y_mask.sum(1)
     return unique_tgt_tokens, scores
Example #24
 def apply_adapter(self, enc_out):
     if self.adapter is None:
         return enc_out
     rst = self.adapter(enc_out.encoder_out)
     if enc_out.encoder_padding_mask is not None:
         rst.masked_fill_(
             enc_out.encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0)
     return EncoderOut(
         encoder_out=rst,
         encoder_padding_mask=enc_out.encoder_padding_mask,
         encoder_embedding=enc_out.encoder_embedding,
         encoder_states=enc_out.encoder_states,
         src_tokens=enc_out.src_tokens,
         src_lengths=enc_out.src_lengths,
     )
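self.adapter is not defined in this snippet; one common choice (an assumption here, not necessarily what this model uses) is a position-wise bottleneck applied to the last dimension of the T x B x C encoder output:

import torch.nn as nn


def build_adapter(embed_dim: int, bottleneck_dim: int = 256) -> nn.Module:
    # Hypothetical bottleneck adapter; operates on the channel (last) dimension.
    return nn.Sequential(
        nn.LayerNorm(embed_dim),
        nn.Linear(embed_dim, bottleneck_dim),
        nn.ReLU(),
        nn.Linear(bottleneck_dim, embed_dim),
    )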
Example #25
    def forward(
        self,
        src_tokens,
        src_lengths,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        # B x C x H x W
        x = src_tokens
        # TODO: compute padding mask from lengths
        # see causal ST encoder (w/ subsampler) for reference of how to get mask
        # from fairseq.data.data_utils import lengths_to_padding_mask

        encoder_padding_mask = None
        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for block in self.vggblocks:
            x, extra1 = block(x, return_all_hiddens=return_all_hiddens)
            if return_all_hiddens:
                encoder_states.extend(extra1)

        # B x C x H x W -> B x (HxW) x C
        _b, _c, _h, _w = x.size()
        x = x.view(_b, _c, _h * _w).permute(0, 2, 1)

        if self.embed_positions is not None:
            pos = x.new_ones((_b, _h * _w))  # B x T
            if pos.size(-1) > self.max_positions():
                raise ValueError(
                    "tokens exceeds maximum length: {} > {}".format(
                        pos.size(-1), self.max_positions()))
            x = x + self.embed_positions(pos)  # input: B x T

        # B x (HxW) x C -> (HxW) x B x C
        x = x.permute(1, 0, 2)
        x = self.dropout_module(x)

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=None,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
        )
Example #26
 def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
     encoder_padding_mask: Optional[
         Tensor] = encoder_out.encoder_padding_mask
     src_lengths: Optional[Tensor] = encoder_out.src_lengths
     new_encoder_padding_mask = (
         encoder_padding_mask if encoder_padding_mask is None else
         encoder_padding_mask.index_select(1, new_order))
     new_src_lengths = (src_lengths if src_lengths is None else
                        src_lengths.index_select(0, new_order))
     return EncoderOut(
         encoder_out=encoder_out.encoder_out.index_select(1, new_order),
         encoder_padding_mask=new_encoder_padding_mask,
         encoder_embedding=None,
         encoder_states=None,
         src_tokens=None,
         src_lengths=new_src_lengths,
     )
Example #27
 def forward(self, src_tokens, src_lengths, **kwargs):
     self.wav2vec_model.eval()
     with torch.no_grad():
         z = self.wav2vec_model.feature_extractor(src_tokens.squeeze())
         c = self.wav2vec_model.feature_aggregator(z).permute(0, 2, 1)
     subsample_factor = src_tokens.shape[1] / c.shape[1]
     src_lengths = torch.ceil(src_lengths / subsample_factor).type(torch.int64)
     src_lengths = torch.min(torch.tensor(c.shape[1]).to(src_lengths.device), src_lengths)
     d = super().forward(c, src_lengths, **kwargs)
     epm = d.get('encoder_padding_mask', None)
     epm = epm.t() if epm is not None else None
     return EncoderOut(
         encoder_out=d['encoder_out'],  # T x B x C
         encoder_padding_mask=epm,  # B x T
         encoder_embedding=None,  # B x T x C
         encoder_states=None,  # List[T x B x C]
     )
Example #28
    def forward(self, tgt, enc_out, src_len):
        assert self.is_initialized

        if self.impl == "fairseq":

            B, L, H = enc_out.shape

            encoder_out = EncoderOut(
                enc_out.transpose(0, 1),  # T x B x C
                torch.arange(L, device=src_len.device).unsqueeze(0).expand((B, L))
                - src_len.unsqueeze(1) >= 1,  # B x L padding mask derived from src_len
                None, None, None, None,
            )
            output, _ = self.model.forward(tgt,
                                           encoder_out=encoder_out,
                                           src_lengths=src_len)

        return output
Example #29
    def forward(
        self,
        src_tokens,
        src_lengths,
        cls_input: Optional[Tensor] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        x, encoder_embedding = self.forward_embedding(src_tokens)
        x = x.transpose(0, 1)  # B x T x C -> T x B x C

        # U-Net part:
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        x = self.forward_unet(x, encoder_padding_mask)

        # if not return_all_hiddens, encoder_states is expected to be an empty list
        encoder_states = []

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=src_tokens,
            src_lengths=src_lengths,
        )
Example #30
    def forward(
        self,
        src_tokens,
        src_lengths: Tensor,
        enforce_sorted: bool = True,
        **unused,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of
                shape `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of
                shape `(batch)`
            enforce_sorted (bool, optional): if True, `src_tokens` is
                expected to contain sequences sorted by length in a
                decreasing order. If False, this condition is not
                required. Default: True.
        """
        out = super().forward(src_tokens,
                              src_lengths,
                              enforce_sorted=enforce_sorted,
                              **unused)
        x, encoder_padding_mask, x_lengths = out.encoder_out, out.encoder_padding_mask, out.src_lengths

        # determine which output frame to select for loss evaluation/test, assuming
        # all examples in a batch are of the same length for chunk-wise training/test
        if (self.out_chunk_end is not None
                and (self.training or not self.training_stage)):
            x = x[self.out_chunk_begin:
                  self.out_chunk_end]  # T x B x C -> W x B x C
            x_lengths = x_lengths.fill_(x.size(0))
            assert encoder_padding_mask is None

        if self.fc_out is not None:
            x = self.fc_out(x)  # T x B x C -> T x B x V

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask
            if encoder_padding_mask is not None and encoder_padding_mask.any()
            else None,  # T x B
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=x_lengths,  # B
        )