def forward(self,
            trg_embed: Tensor = None,
            encoder_output=None,
            src_mask: Tensor = None,
            trg_mask: Tensor = None,
            **kwargs):
    """
    Transformer decoder forward pass.

    :param trg_embed: embedded targets
    :param encoder_output: source representations
    :param src_mask: source padding mask
    :param trg_mask: to mask out target paddings
        Note that a subsequent mask is applied here.
    :return: decoder output, pre-output hidden states, and two None
        placeholders (no attention probabilities are returned here)
    """
    assert trg_mask is not None, "trg_mask is required for Transformer"

    x = self.pe(trg_embed)  # add position encoding to word embedding
    x = self.emb_dropout(x)

    trg_mask = trg_mask & subsequent_mask(
        trg_embed.size(1)).type_as(trg_mask)

    for layer in self.layers:
        x = layer(x=x, memory=encoder_output.states,
                  src_mask=src_mask, trg_mask=trg_mask)

    x = self.layer_norm(x)
    output = self.output_layer(x)

    return output, x, None, None
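# Each forward pass in this section combines the target padding mask with a
# causal "subsequent" mask. The helper itself is not shown in these snippets;
# the following is a minimal, self-contained sketch assuming the usual
# lower-triangular implementation (the project's actual helper may differ).
import torch


def subsequent_mask(size: int) -> torch.Tensor:
    """Boolean mask of shape (1, size, size) in which position i may only
    attend to positions <= i."""
    return torch.tril(torch.ones(1, size, size)).bool()


if __name__ == "__main__":
    # combining a padding mask (batch x 1 x trg_len) with the subsequent mask
    # broadcasts to batch x trg_len x trg_len, as in the forward passes
    trg_mask = torch.tensor([[[True, True, True, False]]])
    combined = trg_mask & subsequent_mask(4).type_as(trg_mask)
    print(combined.shape)  # torch.Size([1, 4, 4])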
def forward(self,
            trg_embed: Tensor = None,
            encoder_outputs: dict = None,  # note the change
            src_mask: Tensor = None,
            trg_mask: Tensor = None,
            inflection_mask: Tensor = None,
            **kwargs):
    """
    Transformer decoder forward pass.

    :param trg_embed: embedded targets
    :param encoder_outputs: dict of source representations, one entry per encoder
    :param src_mask: source padding mask
    :param trg_mask: to mask out target paddings
        Note that a subsequent mask is applied here.
    :param inflection_mask: padding mask for the inflection encoder outputs
    :return: output, pre-output states, and a dict of attention probs
        (context attn probs: batch x layer x head x tgt x src)
    """
    assert trg_mask is not None, "trg_mask is required for Transformer"

    # encoder_outputs is not an ideal name for this dict because it also
    # contains encoder_hidden, which transformers do not use
    enc_outputs = {k: v[0] for k, v in encoder_outputs.items()}

    x = self.pe(trg_embed)  # add position encoding to word embedding
    x = self.emb_dropout(x)

    trg_mask = trg_mask & subsequent_mask(
        trg_embed.size(1)).type_as(trg_mask)

    self_attn_layers = []
    context_attn_layers = defaultdict(list)
    for layer in self.layers:
        # todo: multiple (variable?) numbers of context attentions
        # return a dictionary of context attentions
        x, self_attn, ctx_attn = layer(x=x,
                                       memories=enc_outputs,
                                       src_mask=src_mask,
                                       trg_mask=trg_mask,
                                       inflection_mask=inflection_mask)
        self_attn_layers.append(self_attn)
        for enc_name, enc_ctx_attn in ctx_attn.items():
            context_attn_layers[enc_name].append(enc_ctx_attn)

    # todo: inflection_trg attention
    trg_trg_attn = torch.stack(self_attn_layers, dim=1)
    attn = {"trg_trg": trg_trg_attn}
    for enc_name, enc_ctx_attn_layers in context_attn_layers.items():
        enc_stacked_attn = torch.stack(enc_ctx_attn_layers, dim=1)
        attn[enc_name + "_trg"] = enc_stacked_attn

    x = self.layer_norm(x)
    output = self.output_layers["vocab"](x)

    return output, x, attn, None
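# Toy illustration of the per-encoder attention bookkeeping in the
# multi-source forward pass above: each layer contributes one attention
# tensor per encoder (batch x head x trg x src), and stacking along dim=1
# yields the documented batch x layer x head x tgt x src shape. The encoder
# names "src" and "inflection" are assumptions for the example, not taken
# from the snippet.
import torch
from collections import defaultdict

num_layers, batch, heads, trg_len, src_len = 2, 3, 4, 5, 6
context_attn_layers = defaultdict(list)
for _ in range(num_layers):
    for enc_name in ("src", "inflection"):
        context_attn_layers[enc_name].append(
            torch.rand(batch, heads, trg_len, src_len))

attn = {name + "_trg": torch.stack(layers, dim=1)
        for name, layers in context_attn_layers.items()}
print({k: tuple(v.shape) for k, v in attn.items()})
# {'src_trg': (3, 2, 4, 5, 6), 'inflection_trg': (3, 2, 4, 5, 6)}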
def forward(self,
            trg_embed: Tensor = None,
            encoder_output: Tensor = None,
            encoder_hidden: Tensor = None,
            src_mask: Tensor = None,
            unroll_steps: int = None,
            hidden: Tensor = None,
            trg_mask: Tensor = None,
            **kwargs):
    """
    Transformer decoder forward pass.

    :param trg_embed: embedded targets
    :param encoder_output: source representations
    :param encoder_hidden: unused
    :param src_mask: source padding mask
    :param unroll_steps: unused
    :param hidden: unused
    :param trg_mask: to mask out target paddings
        Note that a subsequent mask is applied here.
    :param kwargs:
    :return: output, pre-output states, None, None
    """
    assert trg_mask is not None, "trg_mask required for Transformer"

    x = self.pe(trg_embed)  # add position encoding to word embedding
    x = self.emb_dropout(x)

    trg_mask = trg_mask & subsequent_mask(
        trg_embed.size(1)).type_as(trg_mask)

    for layer in self.layers:
        x = layer(x=x, memory=encoder_output,
                  src_mask=src_mask, trg_mask=trg_mask)

    x = self.layer_norm(x)
    output = self.output_layer(x)

    # output, hidden, att_scores, att_vectors
    # <--- the last three don't matter for the transformer
    return output, x, None, None
def forward(self,
            trg_embed: Tensor = None,
            encoder_output: Tensor = None,
            encoder_hidden: Tensor = None,
            src_mask: Tensor = None,
            unroll_steps: int = None,
            hidden: Tensor = None,
            trg_mask: Tensor = None,
            **kwargs):
    """
    Transformer decoder forward pass.

    :param trg_embed: embedded targets
    :param encoder_output: source representations
    :param encoder_hidden: unused
    :param src_mask: source padding mask
    :param unroll_steps: unused
    :param hidden: unused
    :param trg_mask: to mask out target paddings
        Note that a subsequent mask is applied here.
    :param kwargs:
    :return: output, pre-output states, None, None
    """
    assert trg_mask is not None, "trg_mask required for Transformer"

    x = self.pe(trg_embed)  # add position encoding to word embedding
    x = self.emb_dropout(x)

    trg_mask = trg_mask & subsequent_mask(
        trg_embed.size(1)).type_as(trg_mask)

    # collect the hidden states after each decoder layer (currently unused)
    inbetween_layers = []
    for layer in self.layers:
        x = layer(x=x, memory=encoder_output,
                  src_mask=src_mask, trg_mask=trg_mask)
        inbetween_layers.append(x)

    x = self.layer_norm(x)
    output = self.output_layer(x)

    return output, x, None, None
def _forward_pre_output(self,
                        trg_embed: Tensor = None,
                        encoder_output: Tensor = None,
                        src_mask: Tensor = None,
                        trg_mask: Tensor = None,
                        **kwargs):
    """
    Transformer decoder forward pass.

    :param trg_embed: embedded targets
    :param encoder_output: source representations
    :param src_mask: source padding mask
    :param trg_mask: to mask out target paddings
        Note that a subsequent mask is applied here.
    :return: pre-output states and a dict of attention probs
        (context attn probs: batch x layer x head x tgt x src)
    """
    assert trg_mask is not None, "trg_mask is required for Transformer"

    x = self.pe(trg_embed)  # add position encoding to word embedding
    x = self.emb_dropout(x)

    trg_mask = trg_mask & subsequent_mask(
        trg_embed.size(1)).type_as(trg_mask)

    self_attn_layers = []
    context_attn_layers = []
    for layer in self.layers:
        x, self_attn, ctx_attn = layer(x=x, memory=encoder_output,
                                       src_mask=src_mask, trg_mask=trg_mask)
        self_attn_layers.append(self_attn)
        context_attn_layers.append(ctx_attn)

    x = self.layer_norm(x)

    trg_trg_attn = torch.stack(self_attn_layers, dim=1)
    src_trg_attn = torch.stack(context_attn_layers, dim=1)
    attn = {"trg_trg": trg_trg_attn, "src_trg": src_trg_attn}

    return x, attn, None
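# A hypothetical wrapper showing how _forward_pre_output might be used: the
# caller applies the vocabulary projection to the returned pre-output states.
# self.output_layer is an assumption here; it is not defined in the snippet
# above, which only returns (x, attn, None).
def forward(self, **kwargs):
    x, attn, _ = self._forward_pre_output(**kwargs)
    output = self.output_layer(x)
    return output, x, attn, None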