Example #1
    def inference(self, preds, memory, memory_mask=None, cache=None):

        # preds: [batch_size, T] token ids decoded so far.
        assert preds.dim() == 2
        dec_output = self.embedding(preds)
        dec_output, pos = self.pos_encoding.inference(dec_output)
        mask = get_transformer_decoder_mask(preds)

        new_caches = []
        attn_weights = {}
        for i, block in enumerate(self.blocks):
            # Reuse this block's cache from the previous decoding step.
            block_cache = cache[i] if cache is not None else {
                'slf': None,
                'src': None
            }
            dec_output, attn_weight, block_cache = block.inference(
                dec_output,
                mask,
                memory,
                memory_mask.unsqueeze(1) if memory_mask is not None else None,
                pos,
                cache=block_cache)
            attn_weights['dec_block_%d' % i] = attn_weight
            new_caches.append(block_cache)

        if self.normalize_before:
            dec_output = self.after_norm(dec_output)

        logits = self.output_layer(
            dec_output)  # logits: [batch_size, T, vocab_size]

        log_probs = F.log_softmax(
            logits[:, -1, :],
            dim=-1)  # log_probs for the newest position: [batch_size, vocab_size]

        return log_probs, new_caches, attn_weights
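
Usage sketch: a minimal greedy-decoding loop built on inference() above. The names greedy_decode, sos_id, eos_id, and max_len are assumptions for illustration, not part of the original code; the loop re-feeds the full prefix each step and threads the per-block caches through, which matches inference() taking only the last position's log-probabilities.

    import torch

    def greedy_decode(decoder, memory, memory_mask, sos_id, eos_id, max_len=100):
        # Hypothetical helper; decoder is assumed to expose inference() as above.
        batch_size = memory.size(0)
        preds = torch.full((batch_size, 1), sos_id,
                           dtype=torch.long, device=memory.device)
        cache = None
        for _ in range(max_len):
            # log_probs covers only the newest position: [batch_size, vocab_size]
            log_probs, cache, _ = decoder.inference(preds, memory,
                                                    memory_mask, cache=cache)
            next_token = log_probs.argmax(dim=-1, keepdim=True)
            preds = torch.cat([preds, next_token], dim=-1)
            if (next_token == eos_id).all():
                break
        return preds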
Example #2
    def forward(self, targets, memory, memory_mask):

        dec_output = self.embedding(targets)
        if self.relative_positional:
            # Signed offsets -(T - 1), ..., T - 1, shape [1, 2T - 1].
            position = torch.arange(-(dec_output.size(1) - 1),
                                    dec_output.size(1),
                                    device=dec_output.device).reshape(1, -1)
            pos = self.pos_emb._embedding_from_positions(position)
        else:
            dec_output, pos = self.pos_emb(dec_output)

        dec_mask = get_transformer_decoder_mask(targets)

        attn_weights = {}
        for i, block in enumerate(self.blocks):
            dec_output, attn_weight = block(dec_output, dec_mask, memory,
                                            memory_mask.unsqueeze(1), pos)
            attn_weights['dec_block_%d' % i] = attn_weight

        if self.normalize_before:
            dec_output = self.after_norm(dec_output)

        logits = self.output_layer(dec_output)

        return logits, attn_weights
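
The relative-positional branch above asks self.pos_emb for embeddings at signed offsets. A minimal sketch of what _embedding_from_positions could compute, assuming a standard sinusoidal encoding evaluated at signed positions (Transformer-XL style) and an even model_size; both assumptions are illustrative, not taken from the original code:

    import math
    import torch

    def embedding_from_positions(position, model_size):
        # position: [1, 2T - 1] signed offsets from -(T - 1) to T - 1.
        # Sketch only: sinusoidal encoding at signed positions is an assumption.
        div_term = torch.exp(
            torch.arange(0, model_size, 2, dtype=torch.float32,
                         device=position.device)
            * -(math.log(10000.0) / model_size))
        angles = position.unsqueeze(-1).float() * div_term  # [1, 2T - 1, model_size / 2]
        pos_emb = torch.zeros(*position.size(), model_size,
                              device=position.device)
        pos_emb[..., 0::2] = torch.sin(angles)  # even channels: sine
        pos_emb[..., 1::2] = torch.cos(angles)  # odd channels: cosine
        return pos_emb  # [1, 2T - 1, model_size]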
Example #3
    def forward(self, targets, memory, memory_mask):

        dec_output = self.embedding(targets)
        dec_output, pos = self.pos_encoding(dec_output)  # absolute positional encoding

        dec_mask = get_transformer_decoder_mask(targets)

        attn_weights = {}
        for i, block in enumerate(self.blocks):
            dec_output, attn_weight = block(dec_output, dec_mask, memory,
                                            memory_mask.unsqueeze(1), pos)
            attn_weights['dec_block_%d' % i] = attn_weight

        if self.normalize_before:
            dec_output = self.after_norm(dec_output)

        logits = self.output_layer(dec_output)

        return logits, attn_weights
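
All three examples lean on get_transformer_decoder_mask(targets), which is not shown here. A minimal sketch under two assumptions: the mask is a causal (lower-triangular) self-attention mask intersected with a key-side padding mask, and the pad id is 0; neither is confirmed by the original code.

    import torch

    def get_transformer_decoder_mask(targets, pad_id=0):
        # targets: [batch_size, T] token ids; pad_id=0 is an assumption.
        t = targets.size(1)
        # Causal part: position i may attend to positions <= i. Shape [1, T, T].
        causal = torch.tril(torch.ones(t, t, dtype=torch.bool,
                                       device=targets.device)).unsqueeze(0)
        # Padding part: True where the key position holds a real token. [batch_size, 1, T]
        pad = targets.ne(pad_id).unsqueeze(1)
        return causal & pad  # [batch_size, T, T]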