Example #1
    def forward(self, encoder_outputs, encoder_output_lengths, decoder_inputs,
                decoder_input_lengths):

        B, T_e, D_e = encoder_outputs.shape
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [S, B, D_e]

        _, T_d = decoder_inputs.shape

        memory_key_padding_mask = utils.get_transformer_padding_byte_masks(
            B, T_e, encoder_output_lengths).to(encoder_outputs.device)
        tgt_key_padding_mask = utils.get_transformer_padding_byte_masks(
            B, T_d, decoder_input_lengths).to(encoder_outputs.device)
        # causal self-attention mask ("casual" is the repo's spelling of "causal")
        casual_masks = utils.get_transformer_casual_masks(T_d).to(
            encoder_outputs.device)

        outputs = self.emb(decoder_inputs) * self.emb_scale
        outputs = self.pe(outputs)
        outputs = self.dropout(outputs)
        outputs = outputs.permute(1, 0, 2)  # [T_d, B, D]

        outputs = self.transformer_block(
            outputs,
            encoder_outputs,
            memory_mask=None,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            tgt_mask=casual_masks)
        outputs = outputs.permute(1, 0, 2)  # back to [B, T_d, D]
        outputs = self.output_affine(outputs)

        return outputs
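
For reference, the casual_masks tensor passed as tgt_mask above is an ordinary causal (lower-triangular) attention mask. The utils implementation is not shown here; below is a minimal sketch of what such a helper plausibly computes, assuming the float-mask convention of PyTorch's nn.Transformer (0.0 on allowed positions, -inf on future ones). The real version may differ in dtype or convention.

import torch

def get_transformer_casual_masks_sketch(T):
    # Hypothetical re-implementation for illustration only.
    # Entry (i, j) is -inf when j > i, so each position may attend
    # only to itself and earlier positions.
    return torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
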
Example #2
    def step_forward(self, encoded, len_encoded, decoder_inputs):
        device = encoded.device
        B, T, D = encoded.shape
        _, t = decoder_inputs.shape

        # pad the partial hypothesis out to the encoder length T (assumes t <= T)
        decoder_inputs_pad = F.pad(decoder_inputs, (0, T - t))

        encoded = encoded.permute(1, 0, 2)  # [S, B, D_e]

        src_key_padding_mask = utils.get_transformer_padding_byte_masks(
            B, T, len_encoded).to(device)
        casual_masks = utils.get_transformer_casual_masks(T).to(device)

        decoder_inputs_emb = self.emb(decoder_inputs_pad) * self.emb_scale
        decoder_inputs_emb = self.pe(decoder_inputs_emb)
        decoder_inputs_emb = self.dropout(decoder_inputs_emb)
        decoder_inputs_emb = decoder_inputs_emb.permute(1, 0, 2)

        outputs = self.input_affine(
            torch.cat([encoded, decoder_inputs_emb], -1))

        outputs = self.transformer_block(
            outputs,
            src_key_padding_mask=src_key_padding_mask,
            mask=casual_masks)

        outputs = torch.cat([encoded, outputs], -1)

        outputs = outputs.permute(1, 0, 2)
        outputs = self.output_affine(outputs[:, t - 1, :])  # logits for the current step only

        return outputs
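
step_forward pads the partial hypothesis to the encoder length T and returns logits for position t - 1 only, so it is naturally called once per step from an autoregressive loop. A hedged sketch of such a greedy driver follows; model, sos_id, eos_id, and max_len are hypothetical names not taken from the source, and max_len must not exceed T because of the F.pad call above.

import torch

def greedy_decode(model, encoded, len_encoded, sos_id, eos_id, max_len):
    # Hypothetical driver for step_forward, for illustration only.
    B = encoded.size(0)
    hyps = torch.full((B, 1), sos_id, dtype=torch.long, device=encoded.device)
    for _ in range(max_len - 1):
        logits = model.step_forward(encoded, len_encoded, hyps)  # [B, V]
        next_token = logits.argmax(dim=-1, keepdim=True)         # [B, 1]
        hyps = torch.cat([hyps, next_token], dim=-1)
        if (next_token == eos_id).all():
            break
    return hyps
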
Example #3
    def forward(self, encoder_outputs, decoder_inputs, decoder_input_lengths):
        device = encoder_outputs.device
        B, T, D = encoder_outputs.shape
        encoder_output_lengths = decoder_input_lengths  # encoder and decoder share one time axis here

        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [S, B, D_e]

        src_key_padding_mask = utils.get_transformer_padding_byte_masks(
            B, T, encoder_output_lengths).to(device)
        casual_masks = utils.get_transformer_casual_masks(T).to(device)

        decoder_inputs_emb = self.emb(decoder_inputs) * self.emb_scale
        decoder_inputs_emb = self.pe(decoder_inputs_emb)
        decoder_inputs_emb = self.dropout(decoder_inputs_emb)
        decoder_inputs_emb = decoder_inputs_emb.permute(1, 0, 2)

        outputs = self.input_affine(
            torch.cat([encoder_outputs, decoder_inputs_emb], -1))

        outputs = self.transformer_block(
            outputs,
            src_key_padding_mask=src_key_padding_mask,
            mask=casual_masks)

        outputs = torch.cat([encoder_outputs, outputs], -1)

        outputs = outputs.permute(1, 0, 2)
        outputs = self.output_affine(outputs)

        return outputs
Example #4
def test_get_transformer_padding_byte_masks():
    B = 3
    T = 5
    lengths = torch.tensor([3, 4, 5]).long()
    masks = utils.get_transformer_padding_byte_masks(B, T, lengths)
    print('test_get_transformer_padding_byte_masks')
    print(masks)
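
The test above pins down the helper's contract: given a batch size B, a maximum length T, and per-sequence lengths, it returns a [B, T] mask flagging the padded tail of each row. A minimal sketch of one plausible implementation, assuming True marks padding as PyTorch's key_padding_mask arguments expect; the real version may use a ByteTensor instead of bool, as the name suggests.

import torch

def get_transformer_padding_byte_masks_sketch(B, T, lengths):
    # Hypothetical re-implementation for illustration only.
    # Row b is True at position j when j >= lengths[b], i.e. padding.
    positions = torch.arange(T, device=lengths.device).expand(B, T)
    return positions >= lengths.unsqueeze(1)

With the test's inputs (T = 5, lengths [3, 4, 5]) this would flag the last two positions of the first row, the last position of the second, and nothing in the third.
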
Example #5
    def forward(self, feats, feat_lengths):
        if self.subconf:
            outputs, output_lengths = self.sub(feats, feat_lengths)
        else:
            outputs, output_lengths = self.affine(feats), feat_lengths

        outputs = self.dropout(self.pe(outputs))

        B, T, D_o = outputs.shape
        src_key_padding_mask = utils.get_transformer_padding_byte_masks(
            B, T, output_lengths).to(outputs.device)
        outputs = outputs.permute(1, 0, 2)  # [T, B, D_o]

        outputs = self.transformer_encoder(
            outputs, src_key_padding_mask=src_key_padding_mask)
        outputs = outputs.permute(1, 0, 2)  # back to [B, T, D_o]

        return outputs, output_lengths
Example #6
    def forward(self, feats, feat_lengths, return_atten=False):
        outputs, output_lengths = self.sub(feats, feat_lengths)
        outputs = self.dropout(self.pe(outputs))

        B, T, D_o = outputs.shape
        src_key_padding_mask = utils.get_transformer_padding_byte_masks(
            B, T, output_lengths).to(outputs.device)
        outputs = outputs.permute(1, 0, 2)
        if return_atten:
            outputs, self_atten_list = self.transformer_encoder(
                outputs,
                src_key_padding_mask=src_key_padding_mask,
                return_atten=True)
        else:
            outputs = self.transformer_encoder(
                outputs,
                src_key_padding_mask=src_key_padding_mask,
                return_atten=False)
        outputs = outputs.permute(1, 0, 2)
        if return_atten:
            return outputs, output_lengths, self_atten_list
        return outputs, output_lengths
Example #7
    def forward(self, ids, lengths, return_atten=False):
        B, T = ids.shape

        key_padding_mask = utils.get_transformer_padding_byte_masks(
            B, T, lengths).to(ids.device)
        casual_masks = utils.get_transformer_casual_masks(T).to(ids.device)

        outputs = self.emb(ids) * self.scale
        outputs = self.pe(outputs)
        outputs = self.dropout(outputs)
        outputs = outputs.permute(1, 0, 2)

        # attention weights are always requested here; they are simply
        # discarded when return_atten is False
        outputs, self_atten_list = self.transformer_encoder(
            outputs,
            mask=casual_masks,
            src_key_padding_mask=key_padding_mask,
            return_atten=True)
        outputs = self.output_affine(outputs)  # still time-major; this example does not permute back to [B, T, ...]
        if return_atten:
            return outputs, self_atten_list
        return outputs