class PrimeTransformerEncoder(TransformerEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary=dictionary, embed_tokens=embed_tokens)
        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [self.build_encoder_layer(args, i) for i in range(args.encoder_layers)]
        )

    def build_encoder_layer(self, args, layer_id=0):
        layer = PrimeTransformerEncoderLayer(args, layer_id=layer_id)
        if getattr(args, "checkpoint_activations", False):
            layer = checkpoint_wrapper(layer)
        return layer
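# A note on the `checkpoint_wrapper(layer)` call above: activation checkpointing
# trades compute for memory by discarding a layer's activations in the forward
# pass and recomputing them during backward. A minimal, self-contained sketch of
# the same idea using torch.utils.checkpoint (an illustration only, not
# fairseq's checkpoint_wrapper):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Run the wrapped module under activation checkpointing."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x):
        # use_reentrant=False is the recommended mode in recent PyTorch releases
        return checkpoint(self.module, x, use_reentrant=False)


# example: the inner Linear/ReLU activations are recomputed during backward()
_block = CheckpointedBlock(nn.Sequential(nn.Linear(16, 16), nn.ReLU()))
_block(torch.randn(4, 16, requires_grad=True)).sum().backward()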
class PrimeTransformerDecoder(TransformerDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        final_norm (bool, optional): apply layer norm to the output of the
            final decoder layer (default: True).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        self.args = args
        super().__init__(
            args,
            dictionary=dictionary,
            embed_tokens=embed_tokens,
            no_encoder_attn=no_encoder_attn,
        )
        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend([
            self.build_decoder_layer(args, no_encoder_attn, i)
            for i in range(args.decoder_layers)
        ])

    def build_decoder_layer(self, args, no_encoder_attn=False, layer_id=0):
        layer = PrimeTransformerDecoderLayer(
            args, no_encoder_attn=no_encoder_attn, layer_id=layer_id
        )
        if getattr(args, "checkpoint_activations", False):
            layer = checkpoint_wrapper(layer)
        return layer
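# Both PRIME modules above swap the plain ModuleList for a LayerDropModuleList
# when layerdrop is enabled: during training each layer is skipped with
# probability p (Fan et al., 2019), while at inference every layer runs. A
# minimal sketch of that behaviour with a generic stack of layers:

import torch
import torch.nn as nn


class SimpleLayerDropList(nn.ModuleList):
    """Iterate over the layers, dropping each with probability p while training."""

    def __init__(self, p: float, modules=None):
        super().__init__(modules)
        self.p = p

    def __iter__(self):
        keep = torch.empty(len(self)).uniform_()
        for i, layer in enumerate(super().__iter__()):
            if not self.training or keep[i] > self.p:
                yield layer


# example: roughly half of the six layers are applied on any given pass
_layers = SimpleLayerDropList(0.5, [nn.Linear(8, 8) for _ in range(6)])
_layers.train()
_x = torch.randn(2, 8)
for _layer in _layers:
    _x = _layer(_x)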
class TransformerEncoderBase(FairseqEncoder): """ Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, cfg, dictionary, embed_tokens): self.cfg = cfg super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)) self.encoder_layerdrop = cfg.encoder.layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = cfg.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt( embed_dim) self.embed_positions = (PositionalEmbedding( cfg.max_source_positions, embed_dim, self.padding_idx, learned=cfg.encoder.learned_pos, ) if not cfg.no_token_positional_embeddings else None) if cfg.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) else: self.layernorm_embedding = None if not cfg.adaptive_input and cfg.quant_noise.pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), cfg.quant_noise.pq, cfg.quant_noise.pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend( [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]) self.num_layers = len(self.layers) if cfg.encoder.normalize_before: self.layer_norm = LayerNorm(embed_dim, export=cfg.export) else: self.layer_norm = None def build_encoder_layer(self, cfg): layer = transformer_layer.TransformerEncoderLayerBase(cfg) checkpoint = cfg.checkpoint_activations if checkpoint: offload_to_cpu = cfg.offload_activations layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward_embedding(self, src_tokens, token_embedding: Optional[torch.Tensor] = None): # embed tokens and positions if token_embedding is None: token_embedding = self.embed_tokens(src_tokens) x = embed = self.embed_scale * token_embedding if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def forward( self, src_tokens, src_lengths: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). 
token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ return self.forward_scriptable(src_tokens, src_lengths, return_all_hiddens, token_embeddings) # TorchScript doesn't support super() method so that the scriptable Subclass # can't access the base class model in Torchscript. # Current workaround is to add a helper function with different name and # call the helper function from scriptable Subclass. def forward_scriptable( self, src_tokens, src_lengths: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any( ) x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) # account for padding while computing the representation if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C x = x.transpose(0, 1) encoder_states = [] if return_all_hiddens: encoder_states.append(x) # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. src_lengths = src_tokens.ne(self.padding_idx).sum( dim=1, dtype=torch.int32).reshape(-1, 1).contiguous() return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_embedding": [encoder_embedding], # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": [], "src_lengths": [src_lengths], } @torch.jit.export def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if len(encoder_out["encoder_out"]) == 0: new_encoder_out = [] else: new_encoder_out = [ encoder_out["encoder_out"][0].index_select(1, new_order) ] if len(encoder_out["encoder_padding_mask"]) == 0: new_encoder_padding_mask = [] else: new_encoder_padding_mask = [ encoder_out["encoder_padding_mask"][0].index_select( 0, new_order) ] if len(encoder_out["encoder_embedding"]) == 0: new_encoder_embedding = [] else: new_encoder_embedding = [ encoder_out["encoder_embedding"][0].index_select(0, new_order) ] if len(encoder_out["src_tokens"]) == 0: src_tokens = [] else: src_tokens = [ (encoder_out["src_tokens"][0]).index_select(0, new_order) ] if len(encoder_out["src_lengths"]) == 0: src_lengths = [] else: src_lengths = [ (encoder_out["src_lengths"][0]).index_select(0, new_order) ] encoder_states = encoder_out["encoder_states"] if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return { "encoder_out": new_encoder_out, # T x B x C "encoder_padding_mask": new_encoder_padding_mask, # B x T "encoder_embedding": new_encoder_embedding, # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": src_tokens, # B x T "src_lengths": src_lengths, # B x 1 } def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict["{}.embed_positions._float_tensor".format( name)] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
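# `reorder_encoder_out` above exists so beam search can reshuffle the cached
# encoder states whenever hypotheses are reordered: `new_order` gives, for every
# slot in the new batch, the index it should be copied from. A small indexing
# sketch, assuming batch size 2 and beam size 3 (all numbers are made up):

import torch

T, B, C, beam = 7, 2, 4, 3
encoder_out = torch.randn(T, B, C)  # stand-in for encoder_out["encoder_out"][0]

# expand each sentence `beam` times: new_order = [0, 0, 0, 1, 1, 1]
new_order = torch.arange(B).repeat_interleave(beam)
expanded = encoder_out.index_select(1, new_order)  # T x (B*beam) x C

# later, the search may decide which beam each live hypothesis descends from
# and pass the updated ordering through reorder_encoder_out the same way
reorder = torch.tensor([2, 0, 1, 3, 5, 4])
reordered = expanded.index_select(1, reorder)  # still T x (B*beam) x C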
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__ ) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu") # creating constant positional encoding table with the max source positions pe = torch.zeros(self.max_source_positions, embed_dim) position = torch.arange(0, self.max_source_positions).unsqueeze(1) div_term = torch.exp((torch.arange(0, embed_dim, 2, dtype=torch.float) * -(math.log(10000.0) / embed_dim))) pe[:, 0::2] = torch.sin(position.float() * div_term) pe[:, 1::2] = torch.cos(position.float() * div_term) self.constant_positional_encoding = pe.to(self.device_) # position layers self.position_layers = args.position_layers self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) self.embed_positions = ( PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None ) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend( [self.build_encoder_layer(args) for i in range(args.encoder_layers)] ) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.pos_weight = nn.Linear(embed_dim, embed_dim, bias=True).to(self.device_) def build_encoder_layer(self, args): return TransformerEncoderLayer(args) def forward_embedding( self, src_tokens, token_embedding: Optional[torch.Tensor] = None ): # embed tokens and positions if token_embedding is None: token_embedding = self.embed_tokens(src_tokens) x = embed = self.embed_scale * token_embedding if self.embed_positions is not None and self.position_layers is None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def position_attention(self, x, pos_table, sf_type): # T x B x C -> B x T x C x = x.transpose(0, 1) dim = x.size(-1) # pos_table : B x P x C -> B x C x P, scores: B x T (# words) x P (# positions) x = self.pos_weight(x) scores = torch.matmul(x, pos_table.transpose(-2, -1)) / math.sqrt(dim) if sf_type == 1: pos_attention = F.softmax(scores, dim=-1) else: pos_attention = F.gumbel_softmax(scores, tau=0.5, dim=-1) # [B x T (words) x P (positions)] x [B x P(positions) x C] -> B x T x C reordered_pos = 
torch.matmul(pos_attention, pos_table) # B x T x C -> T x B x C reordered_pos = reordered_pos.transpose(0, 1) return reordered_pos, pos_attention def p_position_attention(self, x, pos_table, sf_type): dim = x.size(-1) x_pos = self.constant_positional_encoding[:x.size(0)] # x_pos B x T x C x_pos = x_pos.repeat(x.size(1), 1).view(x.size(1), x.size(0), -1).to(x) # x_pos B x T x C multiplied with weight C x C x_pos_weight = self.pos_weight(x_pos) # B x T x C x_pos multiplied with B x C x P pos_table: scores B x T x P scores = torch.matmul(x_pos_weight, pos_table.transpose(-2, -1)) /math.sqrt(dim) if sf_type == 1: pos_attention = F.softmax(scores, dim=-1) else: pos_attention = F.gumbel_softmax(scores, dim=-1) # B x T x P with B x P x C becomes B x T x C reordered_pos = torch.matmul(pos_attention, pos_table) reordered_pos = reordered_pos.transpose(0,1) return reordered_pos, pos_attention def forward( self, src_tokens, src_lengths, max_target_position, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. 
""" x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) num_sentences, src_len, d = x.size() # create position table of target positions (P x C ) where P stands for target positions pos_table = self.constant_positional_encoding[:max_target_position] # repeat position table for each of the sentences (B x P x C) pos_table = pos_table.repeat(num_sentences, 1).view(num_sentences, max_target_position, -1).to(x) if self.position_layers is not None: num_layers = len(self.position_layers) # stores position attention probabilities for each layer L x B x T x P probability = torch.empty(num_layers, num_sentences, src_len, max_target_position).to(x) else: probability = None # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) encoder_states = [] if return_all_hiddens else None # encoder layers for idx, layer in enumerate(self.layers): if self.position_layers is not None and idx in self.position_layers: reordered_position, pos_attention = self.position_attention(x, pos_table,2) probability[self.position_layers.index(idx)] = pos_attention x = x + reordered_position x = layer(x, encoder_padding_mask) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) return EncoderOut( encoder_out=x, # T x B x C encoder_padding_mask=encoder_padding_mask, # B x T encoder_embedding=encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=None, src_lengths=None, ), probability def forward_torchscript(self, net_input: Dict[str, Tensor]): """A TorchScript-compatible version of forward. Encoders which use additional arguments may want to override this method for TorchScript compatibility. """ if torch.jit.is_scripting(): return self.forward( src_tokens=net_input["src_tokens"], src_lengths=net_input["src_lengths"], max_target_position=net_input["max_target_position"] ) else: return self.forward_non_torchscript(net_input) @torch.jit.export def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ """ Since encoder_padding_mask and encoder_embedding are both of type Optional[Tensor] in EncoderOut, they need to be copied as local variables for Torchscript Optional refinement """ encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding new_encoder_out = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order) ) new_encoder_padding_mask = ( encoder_padding_mask if encoder_padding_mask is None else encoder_padding_mask.index_select(0, new_order) ) new_encoder_embedding = ( encoder_embedding if encoder_embedding is None else encoder_embedding.index_select(0, new_order) ) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( encoder_out=new_encoder_out, # T x B x C encoder_padding_mask=new_encoder_padding_mask, # B x T encoder_embedding=new_encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 ) def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i) ) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
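# The encoder above learns a soft reordering over *target* positions: each
# source state attends over a per-sentence table of sinusoidal target-position
# vectors and receives a convex combination of them, which is added back to the
# state before the layers listed in args.position_layers. A toy, shape-only
# sketch of `position_attention` with random tensors (all sizes are made up):

import math
import torch
import torch.nn.functional as F

T, B, C, P = 5, 2, 8, 6                       # src_len, batch, channels, target positions
x = torch.randn(T, B, C)                      # encoder states, T x B x C
pos_table = torch.randn(B, P, C)              # per-sentence target-position table
pos_weight = torch.nn.Linear(C, C)

xw = pos_weight(x.transpose(0, 1))                           # B x T x C
scores = xw @ pos_table.transpose(-2, -1) / math.sqrt(C)     # B x T x P
pos_attention = F.gumbel_softmax(scores, tau=0.5, dim=-1)    # sharp, near one-hot weights
reordered_pos = (pos_attention @ pos_table).transpose(0, 1)  # back to T x B x C, added to x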
class SrcFactorEncoder(FairseqEncoder): """ Based on TransformerEncoder but adds an additional embedding for src_factor. Changes have comments """ def __init__(self, args, dictionary, embed_tokens, src_factor_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__) self.encoder_layerdrop = args.encoder_layerdrop """ change the embedding dim to account for src_factor """ embed_dim = embed_tokens.embedding_dim + src_factor_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens """ save embedding """ self.src_factor_tokens = src_factor_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( embed_dim) self.embed_positions = (PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_encoder_layer(args) for i in range(args.encoder_layers) ]) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None def forward_embedding(self, src_tokens, src_factor_tokens): """ concat before passing to the rest """ x = self.embed_tokens(src_tokens) x = torch.cat([self.src_factor_tokens(src_factor_tokens), x], -1) x = embed = self.embed_scale * x if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def build_encoder_layer(self, args): return TransformerEncoderLayer(args) def forward(self, src_tokens, src_lengths, src_factor_tokens, src_factor_lengths, return_all_hiddens: bool = False): """ forward embedding signature """ x, encoder_embedding = self.forward_embedding(src_tokens, src_factor_tokens) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) encoder_states = [] if return_all_hiddens else None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) return EncoderOut( encoder_out=x, # T x B x C encoder_padding_mask=encoder_padding_mask, # B x T encoder_embedding=encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=None, src_lengths=None, ) @torch.jit.export def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ """ Since encoder_padding_mask and encoder_embedding are both of type Optional[Tensor] in EncoderOut, they need to be copied as local variables for Torchscript Optional refinement """ encoder_padding_mask: Optional[ Tensor] = encoder_out.encoder_padding_mask encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding new_encoder_out = (encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order)) new_encoder_padding_mask = ( encoder_padding_mask if encoder_padding_mask is None else encoder_padding_mask.index_select(0, new_order)) new_encoder_embedding = (encoder_embedding if encoder_embedding is None else encoder_embedding.index_select( 0, new_order)) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( encoder_out=new_encoder_out, # T x B x C encoder_padding_mask=new_encoder_padding_mask, # B x T encoder_embedding=new_encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 ) def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict["{}.embed_positions._float_tensor".format( name)] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
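# SrcFactorEncoder widens the model dimension by concatenating a source-factor
# embedding (e.g. POS tags or casing features) onto every token embedding, so
# that embed_dim = token_dim + factor_dim. A small sketch of that concatenation
# with made-up vocabulary and dimension sizes:

import torch
import torch.nn as nn

token_emb = nn.Embedding(1000, 512, padding_idx=1)
factor_emb = nn.Embedding(16, 64, padding_idx=1)

src_tokens = torch.randint(2, 1000, (3, 7))   # B x T
src_factors = torch.randint(2, 16, (3, 7))    # B x T, aligned with src_tokens

x = torch.cat([factor_emb(src_factors), token_emb(src_tokens)], dim=-1)
# x.shape == (3, 7, 576); the encoder layers are built with embed_dim = 576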
class SingleShotTransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout = args.dropout self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) self.embed_positions = ( PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None ) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_encoder_layer(args) for i in range(args.encoder_layers) ]) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None def build_encoder_layer(self, args): return TransformerEncoderLayer(args) def forward_embedding(self, src_tokens): # embed tokens and positions x = embed = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = F.dropout(x, p=self.dropout, training=self.training) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def forward( self, src_tokens, src_lengths, return_all_hiddens: bool = False, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. 
""" x, encoder_embedding = self.forward_embedding(src_tokens) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) encoder_states = [] if return_all_hiddens else None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) return EncoderOut( encoder_out=x, # T x B x C encoder_padding_mask=encoder_padding_mask, # B x T encoder_embedding=encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=None, src_lengths=None, ) @torch.jit.export def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ new_encoder_out: Dict[str, Tensor] = {} new_encoder_out["encoder_out"] = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order) ) new_encoder_out["encoder_padding_mask"] = ( encoder_out.encoder_padding_mask if encoder_out.encoder_padding_mask is None else encoder_out.encoder_padding_mask.index_select(0, new_order) ) new_encoder_out["encoder_embedding"] = ( encoder_out.encoder_embedding if encoder_out.encoder_embedding is None else encoder_out.encoder_embedding.index_select(0, new_order) ) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( encoder_out=new_encoder_out["encoder_out"], # T x B x C encoder_padding_mask=new_encoder_out["encoder_padding_mask"], # B x T encoder_embedding=new_encoder_out["encoder_embedding"], # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 ) def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i) ) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
class TransformerDecoder(FairseqIncrementalDecoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self._future_mask = torch.empty(0) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__) self.decoder_layerdrop = args.decoder_layerdrop self.only_drop_topk = args.only_drop_topk self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim self.embed_dim = embed_dim self.output_embed_dim = args.decoder_output_dim self.padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( embed_dim) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None self.project_in_dim = (Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None) self.embed_positions = (PositionalEmbedding( args.max_target_positions, embed_dim, self.padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None self.cross_self_attention = getattr(args, "cross_self_attention", False) if self.decoder_layerdrop > 0.0: if self.only_drop_topk > 0: self.layers = PartLayerDropModuleList( p=self.decoder_layerdrop, top_k=self.only_drop_topk, layer_num=args.decoder_layers) else: self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_decoder_layer(args, no_encoder_attn) for _ in range(args.decoder_layers) ]) self.num_layers = len(self.layers) if args.decoder_normalize_before and not getattr( args, "no_decoder_final_norm", False): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.project_out_dim = (Linear( embed_dim, self.output_embed_dim, bias=False) if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None) self.adaptive_softmax = None self.output_projection = None if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif self.share_input_output_embed: self.output_projection = nn.Linear( self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0], bias=False, ) self.output_projection.weight = self.embed_tokens.weight else: self.output_projection = nn.Linear(self.output_embed_dim, len(dictionary), bias=False) nn.init.normal_(self.output_projection.weight, mean=0, 
std=self.output_embed_dim**-0.5) def build_decoder_layer(self, args, no_encoder_attn=False): return TransformerDecoderLayer(args, no_encoder_attn) def forward( self, prev_output_tokens, encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features( prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, alignment_layer=alignment_layer, alignment_heads=alignment_heads, ) if not features_only: x = self.output_layer(x) return x, extra def extract_features( self, prev_output_tokens, encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, ): return self.extract_features_scriptable( prev_output_tokens, encoder_out, incremental_state, full_context_alignment, alignment_layer, alignment_heads, ) """ A scriptable subclass of this class has an extract_features method and calls super().extract_features, but super() is not supported in torchscript. Aa copy of this function is made to be used in the subclass instead. """ def extract_features_scriptable( self, prev_output_tokens, encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, ): """ Similar to *forward* but only return features. Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al., EMNLP 2019). Args: full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). alignment_layer (int, optional): return mean alignment over heads at this layer (default: last layer). alignment_heads (int, optional): only average alignment over this many heads (default: all heads). 
Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ if alignment_layer is None: alignment_layer = self.num_layers - 1 # embed positions positions = (self.embed_positions(prev_output_tokens, incremental_state=incremental_state) if self.embed_positions is not None else None) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.quant_noise is not None: x = self.quant_noise(x) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) self_attn_padding_mask: Optional[Tensor] = None if self.cross_self_attention or prev_output_tokens.eq( self.padding_idx).any(): self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) # decoder layers attn: Optional[Tensor] = None inner_states: List[Optional[Tensor]] = [x] for idx, layer in enumerate(self.layers): if incremental_state is None and not full_context_alignment: self_attn_mask = self.buffered_future_mask(x) else: self_attn_mask = None x, layer_attn, _ = layer( x, encoder_out.encoder_out if encoder_out is not None else None, encoder_out.encoder_padding_mask if encoder_out is not None else None, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_attn=bool((idx == alignment_layer)), need_head_weights=bool((idx == alignment_layer)), ) inner_states.append(x) if layer_attn is not None and idx == alignment_layer: attn = layer_attn.float().to(x) if attn is not None: if alignment_heads is not None: attn = attn[:alignment_heads] # average probabilities over heads attn = attn.mean(dim=0) if self.layer_norm is not None: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) return x, {"attn": [attn], "inner_states": inner_states} def output_layer(self, features): """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary return self.output_projection(features) else: return features def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions) def buffered_future_mask(self, tensor): dim = tensor.size(0) # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. 
if (self._future_mask.size(0) == 0 or (not self._future_mask.device == tensor.device) or self._future_mask.size(0) < dim): self._future_mask = torch.triu( utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1) self._future_mask = self._future_mask.to(tensor) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict["{}.embed_positions._float_tensor".format( name)] = torch.FloatTensor(1) if f"{name}.output_projection.weight" not in state_dict: if self.share_input_output_embed: embed_out_key = f"{name}.embed_tokens.weight" else: embed_out_key = f"{name}.embed_out" if embed_out_key in state_dict: state_dict[f"{name}.output_projection.weight"] = state_dict[ embed_out_key] if not self.share_input_output_embed: del state_dict[embed_out_key] for i in range(self.num_layers): # update layer norms layer_norm_map = { "0": "self_attn_layer_norm", "1": "encoder_attn_layer_norm", "2": "final_layer_norm", } for old, new in layer_norm_map.items(): for m in ("weight", "bias"): k = "{}.layers.{}.layer_norms.{}.{}".format( name, i, old, m) if k in state_dict: state_dict["{}.layers.{}.{}.{}".format( name, i, new, m)] = state_dict[k] del state_dict[k] version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
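# `buffered_future_mask` above caches a causal mask so that each target position
# can only attend to itself and earlier positions; the -inf entries are added to
# the attention logits before the softmax. A standalone sketch of that mask:

import torch

def future_mask(dim: int) -> torch.Tensor:
    # upper triangle (excluding the diagonal) is -inf, everything else is 0
    return torch.triu(torch.full((dim, dim), float("-inf")), diagonal=1)

# future_mask(4):
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0.,   0., -inf],
#  [0.,   0.,   0.,   0.]]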
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(
            self,
            args,  # passing this in alone provides all of the required command-line parameters; its exact contents are not listed here
            dictionary,
            embed_tokens):
        self.args = args
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__)
        self.encoder_layerdrop = args.encoder_layerdrop

        embed_dim = embed_tokens.embedding_dim  # dimensionality after the tokens have been encoded
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(
            embed_dim)
        # self.src_word_emb = embed_tokens  # changed

        self.embed_positions = (PositionalEmbedding(
            args.max_source_positions,
            embed_dim,
            self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None)

        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

        if not args.adaptive_input and args.quant_noise_pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                args.quant_noise_pq,
                args.quant_noise_pq_block_size,
            )
        else:
            self.quant_noise = None

        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend([
            self.build_encoder_layer(args) for i in range(args.encoder_layers)
        ])
        self.num_layers = len(self.layers)

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def build_encoder_layer(self, args):
        layer = EncoderLayer(args.encoder_embed_dim,
                             args.encoder_ffn_embed_dim,
                             args.encoder_attention_heads, args.kdim,
                             args.vdim, args.dropout)
        if getattr(args, "checkpoint_activations", False):
            offload_to_cpu = getattr(args, "offload_activations", False)
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        return layer

    def forward(self, src_tokens, src_lengths, return_attns=False):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_attns (bool, optional): also return the self-attention
                weights of every encoder layer (default: False).

        Returns:
            dict:
                - **enc_output_real** (Tensor): real component of the last
                  encoder layer's output of shape `(src_len, batch, embed_dim)`
                - **enc_output_phase** (Tensor): phase component of the last
                  encoder layer's output of shape `(src_len, batch, embed_dim)`
                - **slf_attn_mask** (ByteTensor): key padding mask used by
                  self-attention
                - **encoder_states** (List[Tensor]): self-attention weights of
                  every layer. Only populated if *return_attns* is True.
        """
        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_tokens,
                                              seq_q=src_tokens,
                                              padding_idx=self.padding_idx)
        non_pad_mask = get_non_pad_mask(src_tokens, self.padding_idx)

        enc_output_real_real = self.embed_tokens(src_tokens)
        enc_output_phase = self.embed_positions(src_tokens)

        # build src_pos (the position index of every source token)
        src_pos = get_position(src_tokens)
        # add an extra dimension at index 2
        # pos = torch.unsqueeze(torch.LongTensor(src_pos), 2)  # size (64, 52, 1)
        pos = torch.unsqueeze(torch.cuda.LongTensor(src_pos), 2)  # size (64, 52, 1)
        # (64 x 52 x 1) * (64 x 52 x 512) -> 64 x 52 x 512; this is an element-wise product
        enc_output_phase = torch.mul(pos.float(), enc_output_phase)
        cos = torch.cos(enc_output_phase)
        sin = torch.sin(enc_output_phase)
        enc_output_real = enc_output_real_real * cos
        enc_output_phase = enc_output_real_real * sin

        for enc_layer in self.layers:
            enc_output_real, enc_output_phase, enc_slf_attn = enc_layer(
                enc_output_real,
                enc_output_phase,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        # if return_attns:
        #     return enc_output_real, enc_output_phase, enc_slf_attn_list
        # return enc_output_real, enc_output_phase  # size = 64, 512, 52

        # The PyTorch Mobile lite interpreter does not support returning NamedTuple
        # in `forward`, so we use a dictionary instead.
        # TorchScript does not support mixed values, so the values are all lists.
        # The empty list is equivalent to None.
        enc_output_real = enc_output_real.permute(1, 0, 2)
        enc_output_phase = enc_output_phase.permute(1, 0, 2)
        return {
            "slf_attn_mask": [slf_attn_mask],  # B x T
            "enc_output_real": [enc_output_real],  # T x B x C
            "enc_output_phase": [enc_output_phase],  # T x B x C
            "encoder_states": enc_slf_attn_list,  # List[T x B x C]
            "src_tokens": [src_tokens],
            "src_lengths": [],
        }

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]],
                            new_order):
        """
        Reorder encoder output according to *new_order*.
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ # slf_attn_mask if len(encoder_out["slf_attn_mask"]) == 0: new_slf_attn_mask = [] else: new_slf_attn_mask = [ encoder_out["slf_attn_mask"][0].index_select(0, new_order) ] # enc_output_real if len(encoder_out["enc_output_real"]) == 0: new_enc_output_real = [] else: new_enc_output_real = [ encoder_out["enc_output_real"][0].index_select(1, new_order) ] # enc_output_phase if len(encoder_out["enc_output_phase"]) == 0: new_enc_output_phase = [] else: new_enc_output_phase = [ encoder_out["enc_output_phase"][0].index_select(1, new_order) ] if len(encoder_out["src_tokens"]) == 0: src_tokens = [] else: src_tokens = [ (encoder_out["src_tokens"][0]).index_select(0, new_order) ] if len(encoder_out["src_lengths"]) == 0: src_lengths = [] else: src_lengths = [ (encoder_out["src_lengths"][0]).index_select(0, new_order) ] encoder_states = encoder_out["encoder_states"] if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return { "slf_attn_mask": new_slf_attn_mask, # T x B x C "enc_output_real": new_enc_output_real, # B x T "enc_output_phase": new_enc_output_phase, # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": src_tokens, # B x T "src_lengths": src_lengths, # B x 1 } def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict["{}.embed_positions._float_tensor".format( name)] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
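# The encoder above keeps two streams per token in the spirit of complex-valued
# word embeddings: an amplitude embedding split into a real part
# (amplitude * cos(position * frequency)) and a phase part
# (amplitude * sin(position * frequency)). A toy sketch of that construction
# with random tensors (this is not the exact get_position / embed_positions code):

import torch

B, T, C = 2, 5, 8
amplitude = torch.randn(B, T, C)                    # token (amplitude) embedding
freq = torch.randn(B, T, C)                         # per-dimension frequency embedding
pos = torch.arange(1, T + 1).view(1, T, 1).float()  # 1-based positions, broadcast over B and C

phase = pos * freq                                  # element-wise, B x T x C
real_part = amplitude * torch.cos(phase)
imag_part = amplitude * torch.sin(phase)            # both parts are B x T x C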
class SpeechTransformerEncoder(TransformerEncoder): """ Transformer encoder consisting of 2D convolution layers and *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments conv_layers_before (~fairseq.speech_lstm.ConvBNReLU): convolutions before transformer layers input_size (int, optional): dimension of the input to the transformer before being projected to args.encoder_embed_dim """ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_context=None): super(TransformerEncoder, self).__init__(None) # no src dictionary self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = args.encoder_embed_dim self.max_source_positions = args.max_source_positions self.conv_layers_before = conv_layers_before self.fc0 = Linear(input_size, embed_dim) if input_size != embed_dim else None self.embed_positions = (PositionalEmbedding( self.output_lengths(self.max_source_positions), embed_dim, 0, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_encoder_layer(args) for i in range(args.encoder_layers) ]) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.transformer_context = transformer_context def output_lengths(self, in_lengths): return (in_lengths if self.conv_layers_before is None else self.conv_layers_before.output_lengths(in_lengths)) def get_attn_mask(self, in_lengths): """ Create attention mask according to sequence lengths and transformer context. Args: in_lengths (LongTensor): lengths of each input sequence of shape `(batch)` Returns: attn_mask (ByteTensor|BoolTensor, optional): self-attention mask of shape `(tgt_len, src_len)`, where `tgt_len` is the length of output and `src_len` is the length of input, though here both are equal to `seq_len`. `attn_mask[tgt_i, src_j] = 1` means that when calculating the embedding for `tgt_i`, we exclude (mask out) `src_j`. 
""" if (self.transformer_context is None or (self.transformer_context[0] is None and self.transformer_context[1] is None)): return None max_len = in_lengths.data.max() all_ones = in_lengths.ones([max_len, max_len], dtype=torch.bool) # at this point left and right context cannot be both None if self.transformer_context[0] is None: # mask is a triu matrix return all_ones.triu(self.transformer_context[1] + 1) if self.transformer_context[1] is None: # mask is a tril matrix return all_ones.tril(-self.transformer_context[0] - 1) return (all_ones.triu(self.transformer_context[1] + 1) | all_ones.tril(-self.transformer_context[0] - 1)) def forward( self, src_tokens, src_lengths, return_all_hiddens: bool = False, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ if self.conv_layers_before is not None: x, src_lengths, encoder_padding_mask = self.conv_layers_before( src_tokens, src_lengths) else: x, encoder_padding_mask = ( src_tokens, ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1))) x = self.dropout_module(x) if self.fc0 is not None: x = self.fc0(x) if self.embed_positions is not None: # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings x = x + self.embed_positions((~encoder_padding_mask).int()) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) elif self.embed_positions is not None: # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings x = x + self.embed_positions((~encoder_padding_mask).int()) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) # B x T x C -> T x B x C x = x.transpose(0, 1) attn_mask = self.get_attn_mask(src_lengths) encoder_states = [] # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask, attn_mask=attn_mask) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # `foward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T "encoder_embedding": [], "encoder_states": encoder_states, # List[T x B x C] "src_tokens": [], "src_lengths": [], } def max_positions(self): """Maximum input length supported by the encoder.""" return self.max_source_positions
class TransformerSentenceEncoder(nn.Module): """ Implementation for a Bi-directional Transformer based Sentence Encoder used in BERT/XLM style pre-trained models. This first computes the token embedding using the token embedding matrix, position embeddings (if specified) and segment embeddings (if specified). After applying the specified number of TransformerEncoderLayers, it outputs all the internal states of the encoder as well as the final representation associated with the first token (usually CLS token). Input: - tokens: B x T matrix representing sentences - segment_labels: B x T matrix representing segment label for tokens Output: - a tuple of the following: - a list of internal model states used to compute the predictions where each tensor has shape T x B x C - sentence representation associated with first input token in format B x C. """ def __init__( self, padding_idx: int, vocab_size: int, num_encoder_layers: int = 6, embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, layerdrop: float = 0.0, max_seq_len: int = 256, num_segments: int = 2, use_position_embeddings: bool = True, offset_positions_by_padding: bool = True, encoder_normalize_before: bool = False, apply_bert_init: bool = False, activation_fn: str = "relu", learned_pos_embedding: bool = True, embed_scale: float = None, freeze_embeddings: bool = False, n_trans_layers_to_freeze: int = 0, export: bool = False, traceable: bool = False, q_noise: float = 0.0, qn_block_size: int = 8, ) -> None: super().__init__() self.padding_idx = padding_idx self.vocab_size = vocab_size self.dropout_module = FairseqDropout( dropout, module_name=self.__class__.__name__) self.layerdrop = layerdrop self.max_seq_len = max_seq_len self.embedding_dim = embedding_dim self.num_segments = num_segments self.use_position_embeddings = use_position_embeddings self.apply_bert_init = apply_bert_init self.learned_pos_embedding = learned_pos_embedding self.traceable = traceable self.tpu = False # whether we're on TPU self.embed_tokens = self.build_embedding(self.vocab_size, self.embedding_dim, self.padding_idx) self.embed_scale = embed_scale if q_noise > 0: self.quant_noise = apply_quant_noise_( nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), q_noise, qn_block_size, ) else: self.quant_noise = None self.segment_embeddings = (nn.Embedding( self.num_segments, self.embedding_dim, padding_idx=None) if self.num_segments > 0 else None) self.embed_positions = ( PositionalEmbedding( 514, #self.max_seq_len, self.embedding_dim, padding_idx= None, #(self.padding_idx if offset_positions_by_padding else None), learned=self.learned_pos_embedding, ) if self.use_position_embeddings else None) if self.layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_transformer_sentence_encoder_layer( embedding_dim=self.embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=self.dropout_module.p, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, export=export, q_noise=q_noise, qn_block_size=qn_block_size, ) for _ in range(num_encoder_layers) ]) self.crossthought_layers = nn.ModuleList([]) self.crossthought_layers.extend([ self.build_transformer_sentence_encoder_layer( embedding_dim=self.embedding_dim, ffn_embedding_dim=ffn_embedding_dim, 
num_attention_heads=num_attention_heads, dropout=self.dropout_module.p, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, export=export, q_noise=q_noise, qn_block_size=qn_block_size, ) for _ in range(2) ]) self.pre_train = False if 'CTPRETRAIN' not in os.environ else os.environ[ 'CTPRETRAIN'] == 'True' self.short_seq_len = 64 self.sent_emb_num = 5 if encoder_normalize_before: self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export) else: self.emb_layer_norm = None # Apply initialization of model params after building the model if self.apply_bert_init: self.apply(init_bert_params) def freeze_module_params(m): if m is not None: for p in m.parameters(): p.requires_grad = False if freeze_embeddings: freeze_module_params(self.embed_tokens) freeze_module_params(self.segment_embeddings) freeze_module_params(self.embed_positions) freeze_module_params(self.emb_layer_norm) for layer in range(n_trans_layers_to_freeze): freeze_module_params(self.layers[layer]) def build_embedding(self, vocab_size, embedding_dim, padding_idx): return nn.Embedding(vocab_size, embedding_dim, padding_idx) def build_transformer_sentence_encoder_layer( self, embedding_dim, ffn_embedding_dim, num_attention_heads, dropout, attention_dropout, activation_dropout, activation_fn, export, q_noise, qn_block_size, ): return TransformerSentenceEncoderLayer( embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, export=export, q_noise=q_noise, qn_block_size=qn_block_size, ) def prepare_for_tpu_(self, **kwargs): self.tpu = True def forward( self, tokens: torch.Tensor, segment_labels: torch.Tensor = None, last_state_only: bool = False, positions: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # compute padding mask. 
This is needed for multi-head attention if self.pre_train: batch_size_raw, seq_len_raw = tokens.size() assert batch_size_raw == 1 if seq_len_raw % self.short_seq_len == 0: tokens = tokens.view(-1, self.short_seq_len) else: pad_num = (seq_len_raw // self.short_seq_len + 1) * self.short_seq_len - seq_len_raw tokens = F.pad(tokens, (0, pad_num), value=self.padding_idx).view( -1, self.short_seq_len) batch_size, seq_len = tokens.size() tokens = F.pad(tokens, (self.sent_emb_num, 0)) tokens = tokens.detach() else: end_index = torch.nonzero(tokens == 2, as_tuple=False)[:, 1] + 1 tokens_new = [] batch_size_raw, seq_len_raw = tokens.size() for j in range(batch_size_raw): tokens_new.append(tokens[j][:end_index[2 * j]]) tokens_new.append(tokens[j][end_index[2 * j]:end_index[2 * j + 1]]) seq_len = max(v.size(0) for v in tokens_new) tokens = tokens.new(len(tokens_new), seq_len).fill_(self.padding_idx) for j, v in enumerate(tokens_new): tokens[j][:v.size(0)] = v if self.sent_emb_num > 1: tokens = F.pad(tokens, (self.sent_emb_num - 1, 0)) padding_mask = tokens.eq(self.padding_idx) if not self.traceable and not self.tpu and not padding_mask.any(): padding_mask = None if token_embeddings is not None: x = token_embeddings else: x = self.embed_tokens(tokens) if self.embed_scale is not None: x = x * self.embed_scale if self.embed_positions is not None: mask_pos = tokens.ne(self.padding_idx).int() if self.pre_train: mask_pos[:, 0] = 0 mask_pos[:, self.sent_emb_num] = random.randint( 1, 512 - self.sent_emb_num - self.short_seq_len - 1) positions = torch.cumsum(mask_pos, dim=1).type_as(mask_pos).long() x = x + self.embed_positions(tokens, positions=positions) if self.segment_embeddings is not None and segment_labels is not None: x = x + self.segment_embeddings(segment_labels) if self.quant_noise is not None: x = self.quant_noise(x) if self.emb_layer_norm is not None: x = self.emb_layer_norm(x) x = self.dropout_module(x) # account for padding while computing the representation if padding_mask is not None: x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C x = x.transpose(0, 1) inner_states = [] if not last_state_only: inner_states.append(x) for layer in self.layers: x, _ = layer(x, self_attn_padding_mask=padding_mask) if not last_state_only: inner_states.append(x) if self.pre_train: for layer_id, layer in enumerate(self.crossthought_layers): if layer_id == 0: x_sent = x[:self.sent_emb_num].clone().transpose(0, 1) x_sent, _ = layer(x_sent) x[:self.sent_emb_num] = x_sent.transpose(0, 1) else: x, _ = layer(x, self_attn_padding_mask=padding_mask) x = x[self.sent_emb_num:].transpose(0, 1) x = x.contiguous().view(1, -1, x.size(-1))[:, :seq_len_raw].transpose( 0, 1) else: x_sent = x[:self.sent_emb_num].view(self.sent_emb_num, int(x.size(1) / 2), 2, x.size(-1)).clone().transpose( 0, 2) x_sent = x_sent.contiguous().view(2, -1, x.size(-1)) x_sent, _ = self.crossthought_layers[0](x_sent) x_sent = x_sent.mean(0).view(-1, self.sent_emb_num, x.size(-1)) x = x_sent.transpose(0, 1) x[0] = x[:self.sent_emb_num].mean(0) sentence_rep = x[0, :, :] if last_state_only: inner_states = [x] if self.traceable: return torch.stack(inner_states), sentence_rep else: return inner_states, sentence_rep
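# --- Illustrative sketch (not part of the original class): the padding handling used
# in forward() above, i.e. deriving the self-attention padding mask from padding_idx,
# zeroing out padded timesteps, and switching to the T x B x C layout expected by the
# encoder layers. The function name is made up for this example.
from typing import Optional, Tuple

import torch


def example_mask_and_transpose(
    tokens: torch.Tensor, x: torch.Tensor, padding_idx: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    padding_mask = tokens.eq(padding_idx)  # B x T, True at padding positions
    if not padding_mask.any():
        padding_mask = None
    if padding_mask is not None:
        # zero the embeddings at padded positions so they contribute nothing downstream
        x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
    return x.transpose(0, 1), padding_mask  # T x B x C, plus the mask (or None)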
class DeliberaterDecoder(TransformerDecoder): def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): super().__init__(args, dictionary, embed_tokens, no_encoder_attn) if self.decoder_layerdrop > 0.0: self.deliberate_layers = LayerDropModuleList( p=self.decoder_layerdrop) else: self.deliberate_layers = nn.ModuleList([]) self.deliberate_layers.extend([ self.build_deliberate_layer(args, no_encoder_attn) for _ in range(args.decoder_layers) ]) def build_deliberate_layer(self, args, no_encoder_attn): layer = TransformerDeliberaterLayer(args, no_encoder_attn) if getattr(args, "checkpoint_activations", False): offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) return layer def forward( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features( prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, ) if not features_only: x = self.output_layer(x) extra["dec_out"] = self.output_layer(extra["dec_out"]) return x, extra def extract_features_scriptable( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, ): if alignment_layer is None: alignment_layer = self.num_layers - 1 positions = None if self.embed_positions is not None: positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.quant_noise is not None: x = self.quant_noise(x) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) residual = x # B x T x C -> T x B x C x = x.transpose(0, 1) residual = residual.transpose(0, 1) self_attn_padding_mask: Optional[Tensor] = None if self.cross_self_attention or prev_output_tokens.eq( self.padding_idx).any(): self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) # logger.info("Info of self_attn_padding_mask:{}".format(self_attn_padding_mask)) # decoder layers attn: Optional[Tensor] = None 
inner_states: List[Optional[Tensor]] = [x] for idx, layer in enumerate(self.layers): if incremental_state is None and not full_context_alignment: self_attn_mask = self.buffered_future_mask(x) else: self_attn_mask = None x, layer_attn, _ = layer( x, encoder_out["encoder_out"][0] if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) else None, encoder_out["encoder_padding_mask"][0] if (encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0) else None, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_attn=bool((idx == alignment_layer)), need_head_weights=bool((idx == alignment_layer)), ) inner_states.append(x) if layer_attn is not None and idx == alignment_layer: attn = layer_attn.float().to(x) if attn is not None: if alignment_heads is not None: attn = attn[:alignment_heads] # average probabilities over heads attn = attn.mean(dim=0) if self.layer_norm is not None: x = self.layer_norm(x) # deliberater_layer dec_out = x de_attn: Optional[Tensor] = None de_inner_states: List[Optional[Tensor]] = [residual] for idx, layer in enumerate(self.layers): if incremental_state is None and not full_context_alignment: self_attn_mask = self.buffered_future_mask(residual) else: self_attn_mask = None residual, layer_attn, _ = layer._forward( residual, encoder_out["encoder_out"][0] if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) else None, encoder_out["encoder_padding_mask"][0] if (encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0) else None, dec_out, self_attn_padding_mask, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_attn=bool((idx == alignment_layer)), need_head_weights=bool((idx == alignment_layer))) de_inner_states.append(residual) if layer_attn is not None and idx == alignment_layer: de_attn = layer_attn.float().to(residual) if de_attn is not None: if alignment_heads is not None: de_attn = de_attn[:alignment_heads] de_attn = de_attn.mean(dim=0) if self.layer_norm is not None: residual = self.layer_norm(residual) # T x B x C -> B x T x C dec_out = dec_out.transpose(0, 1) residual = residual.transpose(0, 1) if self.project_out_dim is not None: dec_out = self.project_out_dim(dec_out) residual = self.project_out_dim(residual) return residual, { "attn": [attn], "inner_states": inner_states, "de_attn": [de_attn], "de_inner_states": de_inner_states, "dec_out": dec_out } def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict["{}.embed_positions._float_tensor".format( name)] = torch.FloatTensor(1) if f"{name}.output_projection.weight" not in state_dict: if self.share_input_output_embed: embed_out_key = f"{name}.embed_tokens.weight" else: embed_out_key = f"{name}.embed_out" if embed_out_key in state_dict: state_dict[f"{name}.output_projection.weight"] = state_dict[ embed_out_key] if not self.share_input_output_embed: del state_dict[embed_out_key] for i in range(self.num_layers): # update layer norms layer_norm_map = { "0": "self_attn_layer_norm", "1": "encoder_attn_layer_norm", "2": "final_layer_norm", "3": "decoder_attn_layer_norm", } for old, new in layer_norm_map.items(): for m in ("weight", "bias"): k = "{}.layers.{}.layer_norms.{}.{}".format( name, i, old, m) if k in 
state_dict: state_dict["{}.layers.{}.{}.{}".format( name, i, new, m)] = state_dict[k] del state_dict[k] version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
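# --- Illustrative sketch (not part of the original class): a stand-in for the
# buffered_future_mask() call used twice in extract_features_scriptable() above.
# The real implementation caches the mask between calls; this minimal version,
# written here only for illustration, rebuilds it each time.
import torch


def example_future_mask(x: torch.Tensor) -> torch.Tensor:
    # x is T x B x C; returns a T x T matrix with -inf strictly above the diagonal,
    # so step i cannot attend to any later step during teacher-forced decoding.
    t = x.size(0)
    return torch.triu(x.new_full((t, t), float("-inf")), diagonal=1)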
class SpeechTransformerEncoderBase(TransformerEncoderBase): """ Transformer encoder consisting of 2D convolution layers and *cfg.encoder.layers* layers. Each layer is a :class:`TransformerWithRelativePositionalEmbeddingEncoderLayer`. Args: cfg (FairseqDataclass): model configuration pre_encoder (~espresso.modules.ConvBNReLU): convolutions before transformer layers input_size (int, optional): dimension of the input to the transformer before being projected to cfg.encoder.embed_dim """ def __init__( self, cfg, return_fc=False, pre_encoder=None, input_size=83, transformer_context=None, ): self.cfg = cfg super(TransformerEncoderBase, self).__init__(None) # no src dictionary self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__) ) self.encoder_layerdrop = cfg.encoder.layerdrop self.return_fc = return_fc embed_dim = cfg.encoder.embed_dim self.max_source_positions = cfg.max_source_positions self.pre_encoder = pre_encoder self.fc0 = Linear(input_size, embed_dim) if input_size != embed_dim else None self.embed_scale = ( 1.0 if cfg.no_scale_embedding or self.fc0 is not None # always diable scaling if fc0 is present else math.sqrt(embed_dim) ) if ( not cfg.no_token_positional_embeddings and cfg.encoder.relative_positional_embeddings ): logger.info( "disabled encoder's absolute positional embeddings as encoder_relative_positional_embeddings is True." ) self.embed_positions = ( PositionalEmbedding( self.output_lengths(self.max_source_positions), embed_dim, 0, learned=cfg.encoder.learned_pos, ) if not cfg.no_token_positional_embeddings and not cfg.encoder.relative_positional_embeddings else None ) if cfg.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) else: self.layernorm_embedding = None if not cfg.adaptive_input and cfg.quant_noise.pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), cfg.quant_noise.pq, cfg.quant_noise.pq_block_size, ) else: self.quant_noise = None if cfg.encoder.relative_positional_embeddings: if cfg.encoder.learned_pos: rel_pos_embed_list = [ RelativePositionalEmbedding( cfg.encoder.embed_dim, padding_idx=None, max_size=self.output_lengths(cfg.max_source_positions), learned=True, ) for _ in range(cfg.encoder.layers) ] else: rel_pos_embed = RelativePositionalEmbedding( cfg.encoder.embed_dim, padding_idx=None, max_size=None, learned=False, ) # single instance referenced across layers rel_pos_embed_list = [rel_pos_embed] * cfg.encoder.layers else: rel_pos_embed_list = [None] * cfg.encoder.layers if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend( [ self.build_encoder_layer( cfg, positional_embedding=rel_pos_embed_list[i] ) for i in range(cfg.encoder.layers) ] ) self.num_layers = len(self.layers) if cfg.encoder.normalize_before and cfg.encoder.layer_type != "conformer": self.layer_norm = LayerNorm(embed_dim, export=cfg.export) else: self.layer_norm = None self.transformer_context = transformer_context def build_encoder_layer( self, cfg, positional_embedding: Optional[RelativePositionalEmbedding] = None ): if cfg.encoder.layer_type == "transformer": layer_cls = TransformerWithRelativePositionalEmbeddingEncoderLayerBase elif cfg.encoder.layer_type == "conformer": layer_cls = ConformerWithRelativePositionalEmbeddingEncoderLayerBase else: raise NotImplementedError layer = layer_cls( cfg, 
return_fc=self.return_fc, positional_embedding=positional_embedding ) checkpoint = cfg.checkpoint_activations if checkpoint: offload_to_cpu = cfg.offload_activations layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def output_lengths(self, in_lengths): return ( in_lengths if self.pre_encoder is None else self.pre_encoder.output_lengths(in_lengths) ) def get_attn_mask(self, in_lengths): """ Create attention mask according to sequence lengths and transformer context. Args: in_lengths (LongTensor): lengths of each input sequence of shape `(batch)` Returns: attn_mask (ByteTensor|BoolTensor, optional): self-attention mask of shape `(tgt_len, src_len)`, where `tgt_len` is the length of output and `src_len` is the length of input, though here both are equal to `seq_len`. `attn_mask[tgt_i, src_j] = 1` means that when calculating the embedding for `tgt_i`, we exclude (mask out) `src_j`. """ if self.transformer_context is None or ( self.transformer_context[0] is None and self.transformer_context[1] is None ): return None max_len = in_lengths.data.max() all_ones = in_lengths.ones([max_len, max_len], dtype=torch.bool) # at this point left and right context cannot be both None if self.transformer_context[0] is None: # mask is a triu matrix return all_ones.triu(self.transformer_context[1] + 1) if self.transformer_context[1] is None: # mask is a tril matrix return all_ones.tril(-self.transformer_context[0] - 1) return all_ones.triu(self.transformer_context[1] + 1) | all_ones.tril( -self.transformer_context[0] - 1 ) def forward( self, src_tokens, src_lengths, return_all_hiddens: bool = False, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ return self.forward_scriptable(src_tokens, src_lengths, return_all_hiddens) # TorchScript doesn't support super() method so that the scriptable Subclass # can't access the base class model in Torchscript. # Current workaround is to add a helper function with different name and # call the helper function from scriptable Subclass. def forward_scriptable( self, src_tokens, src_lengths, return_all_hiddens: bool = False, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). 
Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ if self.pre_encoder is not None: x, src_lengths, encoder_padding_mask = self.pre_encoder( src_tokens, src_lengths ) else: x, encoder_padding_mask = ( src_tokens, ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)), ) has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() if self.fc0 is not None: x = self.dropout_module(x) x = self.fc0(x) x = self.embed_scale * x if self.embed_positions is not None: # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings x = x + self.embed_positions((~encoder_padding_mask).int()) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) # account for padding while computing the representation if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C x = x.transpose(0, 1) encoder_states = [] fc_results = [] if return_all_hiddens: encoder_states.append(x) attn_mask = self.get_attn_mask(src_lengths) # nested tensor and BT enable layer = self.layers[0] BT_flag = False NT_flag = False # torch version check, BT>=1.12.0 and NT>=1.13.0.dev20220613 # internal format is '1.13.0a0+fb' # external format is '1.13.0.dev20220613'(cpu&gpu) for nightly or "1.11.0"(cpu) or '1.11.0+cu102'(gpu) for stable BT_version = False NT_version = False if "fb" in torch.__version__: BT_version = True NT_version = True else: if "+" in torch.__version__: torch_version = torch.__version__.split("+")[0] else: torch_version = torch.__version__ torch_version = torch_version.split(".") int_version = ( int(torch_version[0]) * 1000 + int(torch_version[1]) * 10 + int(torch_version[2]) ) if len(torch_version) == 3: if int_version >= 1120: BT_version = True if int_version >= 1131: NT_version = True elif len(torch_version) == 4: if int_version >= 1130: BT_version = True # Consider _nested_tensor_from_mask_left_aligned is landed after "20220613" if int_version >= 1131 or ( int_version == 1130 and torch_version[3][3:] >= "20220613" ): NT_version = True if ( BT_version and x.dim() == 3 and layer.load_to_BT and not layer.return_fc and layer.can_use_fastpath and not layer.training and not layer.ever_training and not layer.cfg_checkpoint_activations ): # Batch first can not be justified but needs user to make sure x = x.transpose(0, 1) # Check mask conditions for nested tensor if NT_version: if ( encoder_padding_mask is not None and torch._nested_tensor_from_mask_left_aligned( x, encoder_padding_mask.logical_not() ) ): if not torch.is_grad_enabled() or not x.requires_grad: x = torch._nested_tensor_from_mask( x, encoder_padding_mask.logical_not() ) NT_flag = True BT_flag = True # encoder layers if NT_flag: processing_mask = None else: processing_mask = encoder_padding_mask encoder_padding_mask_out = processing_mask if has_pads else None for layer in self.layers: lr = layer( x, encoder_padding_mask=encoder_padding_mask_out, attn_mask=attn_mask, ) if isinstance(lr, tuple) and len(lr) == 2: x, fc_result = lr else: x = lr fc_result = None if 
return_all_hiddens and not torch.jit.is_scripting(): assert encoder_states is not None encoder_states.append(x) fc_results.append(fc_result) # change back to non-nested and Batch second if NT_flag: x = x.to_padded_tensor(0.0) if NT_flag or BT_flag: x = x.transpose(0, 1) if self.layer_norm is not None: x = self.layer_norm(x) # The PyTorch Mobile lite interpreter does not support returning NamedTuple in # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T "encoder_embedding": [], "encoder_states": encoder_states, # List[T x B x C] "fc_results": fc_results, # List[T x B x C] "src_tokens": [], "src_lengths": [src_lengths], # List[B] } def max_positions(self): """Maximum input length supported by the encoder.""" return self.max_source_positions def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) if len(name) > 0: name += "." old_pattern = "{}conv_layers_before".format(name) new_pattern = "{}pre_encoder".format(name) old_keys = [] for k in state_dict: if k.startswith(old_pattern): old_keys.append(k) for old_key in old_keys: new_key = old_key.replace(old_pattern, new_pattern, 1) state_dict[new_key] = state_dict.pop(old_key) return state_dict
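# --- Illustrative sketch (not part of the original class): the checkpoint-key
# migration performed by upgrade_state_dict_named() above, which renames parameters
# saved under the old "conv_layers_before" prefix to the newer "pre_encoder" prefix
# so old checkpoints keep loading. The helper name is made up for this example.
def example_rename_prefix(state_dict: dict, old_prefix: str, new_prefix: str) -> dict:
    for key in [k for k in state_dict if k.startswith(old_prefix)]:
        state_dict[key.replace(old_prefix, new_prefix, 1)] = state_dict.pop(key)
    return state_dict


# e.g. example_rename_prefix({"encoder.conv_layers_before.conv.0.weight": 0},
#                            "encoder.conv_layers_before", "encoder.pre_encoder")
# -> {"encoder.pre_encoder.conv.0.weight": 0}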
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) self.encoder_layerdrop = args.encoder_layerdrop self.encoder_num_pos = args.encoder_num_pos # number of unique pos tags self.pos_embed_dim = args.encoder_pos_embed_dim # pos vector dimension self.embed_tokens = embed_tokens if self.encoder_num_pos != 0: self.pos_tokens = self.build_pos_embed(dictionary, self.embed_tokens, args.encoder_embed_dim) total_dim = self.embed_tokens.embedding_dim + self.pos_tokens.embedding_dim else: self.pos_tokens = None total_dim = self.embed_tokens.embedding_dim embed_dim = self.embed_tokens.embedding_dim self.padding_idx = self.embed_tokens.padding_idx self.max_source_positions = args.max_source_positions #self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) self.embed_positions = ( PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None ) self.train_en_file = self.load_train() self.pos_tag = self.load_pos_tag() if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(total_dim) else: self.layernorm_embedding = None if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(total_dim, total_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend( [self.build_encoder_layer(args) for i in range(args.encoder_layers)] ) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(total_dim) else: self.layer_norm = None def load_train(self): filename1 = '............/combine.bpe.en' #give exact path of combine.bpe.en file file1 = open(filename1, encoding="utf8").readlines() train_en = [] for lines in file1: train_en.append(lines) return train_en def load_pos_tag(self): filename1 = '............/combine.numbers.en' #give exact path of combine.numbers.en file file1 = open(filename1, encoding="utf8").readlines() pos_tag = [] for lines in file1: pos_tag.append(lines) return pos_tag def build_pos_embed(self, dictionary, embed_tokens, embed_dim): num_tags = self.encoder_num_pos pos_tokens = torch.nn.Embedding(num_tags, self.pos_embed_dim) pos_tokens = utils.normalize_embedding(pos_tokens, 2) return pos_tokens def build_encoder_layer(self, args): return TransformerEncoderLayer(args) def forward_embedding(self, src_tokens): # embed tokens and positions x = embed = self.embed_scale * self.embed_tokens(src_tokens) l = src_tokens.tolist() # make a list length = src_token length tagList = self.processBatch(l) # batch size to 512 size list if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.pos_tokens != 0: x_pos = self.pos_tokens(tagList) # src_tokens->tagList x = torch.cat((x, x_pos), dim=2) if self.layernorm_embedding is not 
None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def processBatch(self, a_sentence): # a_sentence : [[],[],[]] 32 batch = [] # ag lists for all sentences in batch(32) for i in range(len(a_sentence)): single = [] # [] tag list for a single sentence temp_l = a_sentence[i] # [] one from [[],[],......] sentence = self.createSentence(temp_l) indexS = self.getIndex(sentence) postags = self.getPosTagList(sentence, indexS) if(postags != -1): # tags only for words, add tags for <pad> <s> padcount = 0 # <pad> for i in range(len(temp_l)): if temp_l[i] == 1: padcount+=1 if padcount>0: for j in range(padcount): single.append(39) # 39 for the word <pad> single = single + self.convert(postags) single.append(40) # 40 for word <s> batch.append(single) else: single = self.convert(postags) single.append(40) # 40 for word <s> batch.append(single) else: for ele in temp_l: single.append(39) # replace with pad tag batch.append(single) newbatch = torch.tensor(batch, dtype=torch.int64).cuda() # converting to a long tensor return newbatch def convert(self, array): new = [] for ele in array: new.append(int(ele)) return new def createSentence(self, sen): # [1,1,3,4,5,3,4,2] sentence = "" array = sen[:len(sen)-1] # remove <s> and <pad> for i in range(len(array)): if(array[i]!=1): if(i != len(array)-1): sentence += self.dictionary[array[i]]+" " else: sentence += self.dictionary[array[i]]+'\n' # add a new line symbol to match from the file return sentence def getIndex(self, sentence): # get the index of the sentence from the training dataset train = self.train_en_file if sentence in train: indexSentence = train.index(sentence) return indexSentence else: return -1 def getPosTagList(self, sentence, index): posTags = self.pos_tag # [[],[],[]] if(index!=-1): return posTags[index].split() else: return -1 def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ x, encoder_embedding = self.forward_embedding(src_tokens) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) encoder_states = [] if return_all_hiddens else None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) return EncoderOut( encoder_out=x, # T x B x C encoder_padding_mask=encoder_padding_mask, # B x T encoder_embedding=encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=None, src_lengths=None, ) @torch.jit.export def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ """ Since encoder_padding_mask and encoder_embedding are both of type Optional[Tensor] in EncoderOut, they need to be copied as local variables for Torchscript Optional refinement """ encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding new_encoder_out = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order) ) new_encoder_padding_mask = ( encoder_padding_mask if encoder_padding_mask is None else encoder_padding_mask.index_select(0, new_order) ) new_encoder_embedding = ( encoder_embedding if encoder_embedding is None else encoder_embedding.index_select(0, new_order) ) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( encoder_out=new_encoder_out, # T x B x C encoder_padding_mask=new_encoder_padding_mask, # B x T encoder_embedding=new_encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 ) def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i) ) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
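# --- Illustrative sketch (not part of the original class): how forward_embedding()
# above widens the token embedding with a POS-tag embedding via torch.cat, giving a
# model dimension of embed_dim + pos_embed_dim. All sizes and ids below are made up.
import torch
import torch.nn as nn

embed_dim, pos_embed_dim = 8, 4
embed_tokens = nn.Embedding(100, embed_dim, padding_idx=1)
pos_tokens = nn.Embedding(41, pos_embed_dim)  # e.g. tag id 39 = <pad>, 40 = <s>

src_tokens = torch.tensor([[5, 6, 7, 2]])   # B x T word ids
tag_ids = torch.tensor([[10, 11, 12, 40]])  # B x T POS-tag ids, as built by processBatch()

x = torch.cat((embed_tokens(src_tokens), pos_tokens(tag_ids)), dim=2)
assert x.size(-1) == embed_dim + pos_embed_dim  # B x T x total_dim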
class TransformerEncoderBase(FairseqEncoder): """ Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False): self.cfg = cfg super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)) self.encoder_layerdrop = cfg.encoder.layerdrop self.return_fc = return_fc embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = cfg.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt( embed_dim) self.embed_positions = (PositionalEmbedding( cfg.max_source_positions, embed_dim, self.padding_idx, learned=cfg.encoder.learned_pos, ) if not cfg.no_token_positional_embeddings else None) if cfg.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) else: self.layernorm_embedding = None if not cfg.adaptive_input and cfg.quant_noise.pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), cfg.quant_noise.pq, cfg.quant_noise.pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend( [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]) self.num_layers = len(self.layers) if cfg.encoder.normalize_before: self.layer_norm = LayerNorm(embed_dim, export=cfg.export) else: self.layer_norm = None def build_encoder_layer(self, cfg): layer = transformer_layer.TransformerEncoderLayerBase( cfg, return_fc=self.return_fc) checkpoint = cfg.checkpoint_activations if checkpoint: offload_to_cpu = cfg.offload_activations layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward_embedding(self, src_tokens, token_embedding: Optional[torch.Tensor] = None): # embed tokens and positions if token_embedding is None: token_embedding = self.embed_tokens(src_tokens) x = embed = self.embed_scale * token_embedding if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def forward( self, src_tokens, src_lengths: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). 
token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ return self.forward_scriptable(src_tokens, src_lengths, return_all_hiddens, token_embeddings) # TorchScript doesn't support super() method so that the scriptable Subclass # can't access the base class model in Torchscript. # Current workaround is to add a helper function with different name and # call the helper function from scriptable Subclass. def forward_scriptable( self, src_tokens, src_lengths: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. 
""" # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any( ) x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) # account for padding while computing the representation if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C x = x.transpose(0, 1) encoder_states = [] fc_results = [] if return_all_hiddens: encoder_states.append(x) # nested tensor and BT enable layer = self.layers[0] BT_flag = False NT_flag = False # torch version check, BT>=1.12.0 and NT>=1.13.0.dev20220613 # internal format is '1.13.0a0+fb' # external format is '1.13.0.dev20220613'(cpu&gpu) for nightly or "1.11.0"(cpu) or '1.11.0+cu102'(gpu) for stable BT_version = False NT_version = False if "fb" in torch.__version__: BT_version = True NT_version = True else: if "+" in torch.__version__: torch_version = torch.__version__.split("+")[0] else: torch_version = torch.__version__ torch_version = torch_version.split(".") int_version = (int(torch_version[0]) * 1000 + int(torch_version[1]) * 10 + int(torch_version[2])) if len(torch_version) == 3: if int_version >= 1120: BT_version = True if int_version >= 1131: NT_version = True elif len(torch_version) == 4: if int_version >= 1130: BT_version = True # Consider _nested_tensor_from_mask_left_aligned is landed after "20220613" if int_version >= 1131 or (int_version == 1130 and torch_version[3][3:] >= "20220613"): NT_version = True if (BT_version and x.dim() == 3 and layer.load_to_BT and not layer.return_fc and layer.can_use_fastpath and not layer.training and not layer.ever_training and not layer.cfg_checkpoint_activations): # Batch first can not be justified but needs user to make sure x = x.transpose(0, 1) # Check mask conditions for nested tensor if NT_version: if (encoder_padding_mask is not None and torch._nested_tensor_from_mask_left_aligned( x, encoder_padding_mask.logical_not())): if not torch.is_grad_enabled() or not x.requires_grad: x = torch._nested_tensor_from_mask( x, encoder_padding_mask.logical_not()) NT_flag = True BT_flag = True # encoder layers if NT_flag: processing_mask = None else: processing_mask = encoder_padding_mask encoder_padding_mask_out = processing_mask if has_pads else None for layer in self.layers: lr = layer(x, encoder_padding_mask=encoder_padding_mask_out) if isinstance(lr, tuple) and len(lr) == 2: x, fc_result = lr else: x = lr fc_result = None if return_all_hiddens and not torch.jit.is_scripting(): assert encoder_states is not None encoder_states.append(x) fc_results.append(fc_result) # change back to non-nested and Batch second if NT_flag: x = x.to_padded_tensor(0.0) if NT_flag or BT_flag: x = x.transpose(0, 1) if self.layer_norm is not None: x = self.layer_norm(x) # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. 
src_lengths = (src_tokens.ne(self.padding_idx).sum( dim=1, dtype=torch.int32).reshape(-1, 1).contiguous()) return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_embedding": [encoder_embedding], # B x T x C "encoder_states": encoder_states, # List[T x B x C] "fc_results": fc_results, # List[T x B x C] "src_tokens": [], "src_lengths": [src_lengths], } @torch.jit.export def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if len(encoder_out["encoder_out"]) == 0: new_encoder_out = [] else: new_encoder_out = [ encoder_out["encoder_out"][0].index_select(1, new_order) ] if len(encoder_out["encoder_padding_mask"]) == 0: new_encoder_padding_mask = [] else: new_encoder_padding_mask = [ encoder_out["encoder_padding_mask"][0].index_select( 0, new_order) ] if len(encoder_out["encoder_embedding"]) == 0: new_encoder_embedding = [] else: new_encoder_embedding = [ encoder_out["encoder_embedding"][0].index_select(0, new_order) ] if len(encoder_out["src_tokens"]) == 0: src_tokens = [] else: src_tokens = [ (encoder_out["src_tokens"][0]).index_select(0, new_order) ] if len(encoder_out["src_lengths"]) == 0: src_lengths = [] else: src_lengths = [ (encoder_out["src_lengths"][0]).index_select(0, new_order) ] encoder_states = encoder_out["encoder_states"] if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return { "encoder_out": new_encoder_out, # T x B x C "encoder_padding_mask": new_encoder_padding_mask, # B x T "encoder_embedding": new_encoder_embedding, # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": src_tokens, # B x T "src_lengths": src_lengths, # B x 1 } @torch.jit.export def _reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): """Dummy re-order function for beamable enc-dec attention""" return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict["{}.embed_positions._float_tensor".format( name)] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
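# --- Illustrative sketch (not part of the original class): the torch.__version__
# parsing that gates the BetterTransformer (BT) and nested-tensor (NT) fast paths in
# forward_scriptable() above. The function name is made up for this example.
def example_bt_nt_flags(version: str):
    # BT needs torch >= 1.12.0; NT needs torch >= 1.13.0.dev20220613.
    if "fb" in version:  # internal builds such as "1.13.0a0+fb"
        return True, True
    parts = version.split("+")[0].split(".")
    int_version = int(parts[0]) * 1000 + int(parts[1]) * 10 + int(parts[2])
    bt_version, nt_version = False, False
    if len(parts) == 3:  # stable releases such as "1.11.0" or "1.12.1"
        bt_version = int_version >= 1120
        nt_version = int_version >= 1131
    elif len(parts) == 4:  # nightly releases such as "1.13.0.dev20220613"
        bt_version = int_version >= 1130
        nt_version = int_version >= 1131 or (
            int_version == 1130 and parts[3][3:] >= "20220613")
    return bt_version, nt_version


# e.g. example_bt_nt_flags("1.11.0+cu102") -> (False, False)
#      example_bt_nt_flags("1.13.0.dev20220613") -> (True, True)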
class TrexEncoder(FairseqEncoder): """ Adapted from Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (dict of torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens, embed_bytes): self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_bytes.embedding_dim # assume the padding will be the same for a sequences, self.padding_idx = embed_bytes.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens if not self.args.input_combine == 'drop_bytes': self.embed_bytes = embed_bytes if self.args.input_combine == 'cnn': self.byte_combine = ByteCombineCNN(embed_dim, embed_dim) elif self.args.input_combine == 'drop_bytes': self.byte_combine = None else: # input_combine == 'sum' self.byte_combine = ByteCombineSUM() self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( embed_dim) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_encoder_layer(args) for _ in range(args.encoder_layers) ]) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None def build_encoder_layer(self, args): layer = TransformerEncoderLayer(args) checkpoint = getattr(args, "checkpoint_activations", False) if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size min_params_to_wrap = (getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) if not checkpoint else 0) layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward_embedding(self, src_tokens): # embed tokens (static) token_embedding = self.embed_tokens[configs.static_field]( src_tokens[configs.static_field]) # embed auxiliary annotations (inst_pos, op_pos, arch) for field in configs.aux_fields: token_embedding += self.embed_tokens[field](src_tokens[field]) if self.byte_combine is not None: byte_embedding_stack = [] for field in configs.byte_fields: byte_embedding_stack.append(self.embed_bytes( src_tokens[field])) byte_embedding = self.byte_combine( torch.stack(byte_embedding_stack, dim=2)) x = embed = self.embed_scale * (token_embedding + byte_embedding) else: x = embed = self.embed_scale * token_embedding if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def forward( self, src_tokens, src_lengths: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, ): """ Args: src_tokens (LongTensor): dictionary of 
tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ return self.forward_scriptable(src_tokens, src_lengths, return_all_hiddens) # TorchScript doesn't support super() method so that the scriptable Subclass # can't access the base class model in Torchscript. # Current workaround is to add a helper function with different name and # call the helper function from scriptable Subclass. def forward_scriptable(self, src_tokens, src_lengths: Optional[torch.Tensor] = None, return_all_hiddens: bool = False): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ # compute padding mask encoder_padding_mask = src_tokens[configs.byte_fields[0]].eq( self.padding_idx) # padding is from byte sequences has_pads = (src_tokens[configs.byte_fields[0]].device.type == "xla" or encoder_padding_mask.any()) x, encoder_embedding = self.forward_embedding(src_tokens) # account for padding while computing the representation if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C x = x.transpose(0, 1) encoder_states = [] if return_all_hiddens: encoder_states.append(x) # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_embedding": [encoder_embedding], # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": [], "src_lengths": [], } @torch.jit.export def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if len(encoder_out["encoder_out"]) == 0: new_encoder_out = [] else: new_encoder_out = [ encoder_out["encoder_out"][0].index_select(1, new_order) ] if len(encoder_out["encoder_padding_mask"]) == 0: new_encoder_padding_mask = [] else: new_encoder_padding_mask = [ encoder_out["encoder_padding_mask"][0].index_select( 0, new_order) ] if len(encoder_out["encoder_embedding"]) == 0: new_encoder_embedding = [] else: new_encoder_embedding = [ encoder_out["encoder_embedding"][0].index_select(0, new_order) ] if len(encoder_out["src_tokens"]) == 0: src_tokens = [] else: src_tokens = [ (encoder_out["src_tokens"][0]).index_select(0, new_order) ] if len(encoder_out["src_lengths"]) == 0: src_lengths = [] else: src_lengths = [ (encoder_out["src_lengths"][0]).index_select(0, new_order) ] encoder_states = encoder_out["encoder_states"] if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return { "encoder_out": new_encoder_out, # T x B x C "encoder_padding_mask": new_encoder_padding_mask, # B x T "encoder_embedding": new_encoder_embedding, # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": src_tokens, # B x T "src_lengths": src_lengths, # B x 1 } def max_positions(self): """Maximum input length supported by the encoder.""" return self.max_source_positions def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
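# --- Illustrative sketch (not part of the original class): the per-field embedding
# combination in forward_embedding() above. Field names and sizes are made up, and the
# final sum over the stacked byte-field embeddings is only an assumption about what a
# sum-style ByteCombine does; the real module may differ.
import torch
import torch.nn as nn

static_field, aux_fields, byte_fields = "static", ["inst_pos", "arch"], ["byte1", "byte2"]
embed_dim, B, T = 8, 2, 5

embed_tokens = nn.ModuleDict(
    {f: nn.Embedding(50, embed_dim, padding_idx=1) for f in [static_field] + aux_fields})
embed_bytes = nn.Embedding(260, embed_dim, padding_idx=1)
src_tokens = {f: torch.randint(2, 50, (B, T)) for f in [static_field] + aux_fields + byte_fields}

# token embedding = static field plus auxiliary annotation fields
token_embedding = embed_tokens[static_field](src_tokens[static_field])
for field in aux_fields:
    token_embedding = token_embedding + embed_tokens[field](src_tokens[field])

# byte-field embeddings are stacked along a new dimension and then combined (here: summed)
byte_embedding = torch.stack(
    [embed_bytes(src_tokens[f]) for f in byte_fields], dim=2).sum(dim=2)
x = token_embedding + byte_embedding  # B x T x embed_dim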
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( embed_dim) self.embed_positions = (PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_encoder_layer(args) for i in range(args.encoder_layers) ]) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None def build_encoder_layer(self, args): layer = TransformerEncoderLayer(args) if getattr(args, "checkpoint_activations", False): layer = checkpoint_wrapper(layer) return layer def forward_embedding(self, src_tokens, token_embedding: Optional[torch.Tensor] = None): # embed tokens and positions if token_embedding is None: token_embedding = self.embed_tokens(src_tokens) x = embed = self.embed_scale * token_embedding if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def forward( self, src_tokens, src_lengths, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). token_embeddings (torch.Tensor, optional): precomputed embeddings default `None` will recompute embeddings Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. 
""" x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) encoder_states = [] # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # `foward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_embedding": [encoder_embedding], # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": [], "src_lengths": [], } @torch.jit.export def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if len(encoder_out["encoder_out"]) == 0: new_encoder_out = [] else: new_encoder_out = [ encoder_out["encoder_out"][0].index_select(1, new_order) ] if len(encoder_out["encoder_padding_mask"]) == 0: new_encoder_padding_mask = [] else: new_encoder_padding_mask = [ encoder_out["encoder_padding_mask"][0].index_select( 0, new_order) ] if len(encoder_out["encoder_embedding"]) == 0: new_encoder_embedding = [] else: new_encoder_embedding = [ encoder_out["encoder_embedding"][0].index_select(0, new_order) ] if len(encoder_out["src_tokens"]) == 0: src_tokens = [] else: src_tokens = [ (encoder_out["src_tokens"][0]).index_select(0, new_order) ] if len(encoder_out["src_lengths"]) == 0: src_lengths = [] else: src_lengths = [ (encoder_out["src_lengths"][0]).index_select(0, new_order) ] encoder_states = encoder_out["encoder_states"] if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return { "encoder_out": new_encoder_out, # T x B x C "encoder_padding_mask": new_encoder_padding_mask, # B x T "encoder_embedding": new_encoder_embedding, # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": src_tokens, # B x T "src_lengths": src_lengths, # B x 1 } def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict["{}.embed_positions._float_tensor".format( name)] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, feature_dict, embed_tokens, feature_embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.src_dict = dictionary self.feature_dict = feature_dict self.merging_method = args.feature_merge self.dropout = args.dropout self.encoder_layerdrop = args.encoder_layerdrop if self.merging_method == "concat": embed_dim = embed_tokens.embedding_dim +feature_embed_tokens.embedding_dim else: embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.feature_embed_tokens = feature_embed_tokens # self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) self.embed_scale = 1.0 self.embed_positions = ( PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None ) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend( [self.build_encoder_layer(args) for i in range(args.encoder_layers)] ) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None if args.feature_merge =="gate": self.gate_layer = Linear(embed_dim*2, embed_dim) self.gate_sigmoid = torch.nn.Sigmoid() def build_encoder_layer(self, args): return TransformerEncoderLayer(args) def gate(self, x, sub_x): x_concat = torch.cat((x,sub_x),-1) context_gate = self.gate_sigmoid(self.gate_layer(x_concat)) return torch.add(context_gate * x, (1.- context_gate) * sub_x) def forward_embedding(self, src_tokens, feature_tokens, ): # embed tokens and positions if self.merging_method == "concat": x = embed = torch.cat((self.embed_tokens(src_tokens), self.feature_embed_tokens(feature_tokens)),-1) elif self.merging_method == "add": x = embed = torch.add(self.embed_tokens(src_tokens), self.feature_embed_tokens(feature_tokens)) # print('x', x.shape) elif self.merging_method == "gate": x = embed = self.gate(self.embed_tokens(src_tokens), self.feature_embed_tokens(feature_tokens)) else: x = embed =self.embed_tokens(src_tokens) # print("x", x.shape) if self.embed_positions is not None: if self.merging_method is not None: x = embed + self.embed_positions(embed[:,:,0]) else: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = F.dropout(x, p=self.dropout, training=self.training) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def forward(self, src_tokens, src_lengths, features =None, return_all_hiddens: bool = False): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` 
src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ # *features* holds one tag per source token (e.g. <ori> / <rep>), aligned one-to-one with *src_tokens*. x, encoder_embedding = self.forward_embedding(src_tokens, features) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) encoder_states = [] if return_all_hiddens else None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) return EncoderOut( encoder_out=x, # T x B x C encoder_padding_mask=encoder_padding_mask, # B x T encoder_embedding=encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=None, src_lengths=None, ) @torch.jit.export def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*.
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ """ Since encoder_padding_mask and encoder_embedding are both of type Optional[Tensor] in EncoderOut, they need to be copied as local variables for Torchscript Optional refinement """ encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding new_encoder_out = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order) ) new_encoder_padding_mask = ( encoder_padding_mask if encoder_padding_mask is None else encoder_padding_mask.index_select(0, new_order) ) new_encoder_embedding = ( encoder_embedding if encoder_embedding is None else encoder_embedding.index_select(0, new_order) ) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( encoder_out=new_encoder_out, # T x B x C encoder_padding_mask=new_encoder_padding_mask, # B x T encoder_embedding=new_encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 ) def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i) ) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
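# Hedged sketch of the "gate" feature-merging strategy defined above: a sigmoid
# gate computed from the concatenation of the token embedding x and the feature
# embedding sub_x decides, per position and channel, how much of each stream to
# keep. Sizes below are illustrative.
import torch
import torch.nn as nn

embed_dim = 8
gate_layer = nn.Linear(embed_dim * 2, embed_dim)
gate_sigmoid = nn.Sigmoid()

x = torch.randn(2, 5, embed_dim)      # B x T x C token embeddings
sub_x = torch.randn(2, 5, embed_dim)  # B x T x C feature (tag) embeddings
g = gate_sigmoid(gate_layer(torch.cat((x, sub_x), dim=-1)))
merged = g * x + (1.0 - g) * sub_x    # same B x T x C shape as x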
class SingleShotTransformerDecoder(FairseqEncoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout = args.dropout self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions # self.embed_tokens = embed_tokens self.output_embed_dim = args.decoder_output_dim self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) # self.embed_positions = ( # PositionalEmbedding( # args.max_source_positions, # embed_dim, # self.padding_idx, # learned=args.encoder_learned_pos, # ) # if not args.no_token_positional_embeddings # else None # ) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_encoder_layer(args) for i in range(args.encoder_layers) ]) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None self.project_out_dim = ( Linear(embed_dim, self.output_embed_dim, bias=False) if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None ) self.output_projection = nn.Linear( embed_tokens.weight.shape[1], embed_tokens.weight.shape[0], bias=False, ) self.output_projection.weight = embed_tokens.weight def build_encoder_layer(self, args): return TransformerEncoderLayer(args) def forward( self, x, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. 
""" # encoder layers for layer in self.layers: x = layer(x, None) if self.layer_norm is not None: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) return self.output_layer(x) def output_layer(self, features): """Project features to the vocabulary size.""" return self.output_projection(features) def max_positions(self): """Maximum output length supported by the decoder.""" return self.max_target_positions def buffered_future_mask(self, tensor): dim = tensor.size(0) # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. if ( self._future_mask.size(0) == 0 or (not self._future_mask.device == tensor.device) or self._future_mask.size(0) < dim ): self._future_mask = torch.triu( utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 ) self._future_mask = self._future_mask.to(tensor) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) if self.share_input_output_embed: embed_out_key = f"{name}.embed_tokens.weight" else: embed_out_key = f"{name}.embed_out" if embed_out_key in state_dict: state_dict[f"{name}.output_projection.weight"] = state_dict[embed_out_key] if not self.share_input_output_embed: del state_dict[embed_out_key] for i in range(self.num_layers): # update layer norms layer_norm_map = { "0": "self_attn_layer_norm", "1": "encoder_attn_layer_norm", "2": "final_layer_norm", } for old, new in layer_norm_map.items(): for m in ("weight", "bias"): k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) if k in state_dict: state_dict[ "{}.layers.{}.{}.{}".format(name, i, new, m) ] = state_dict[k] del state_dict[k] version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
class TransformerSentenceEncoder(nn.Module): """ Implementation for a Bi-directional Transformer based Sentence Encoder used in BERT/XLM style pre-trained models. This first computes the token embedding using the token embedding matrix, position embeddings (if specified) and segment embeddings (if specified). After applying the specified number of TransformerEncoderLayers, it outputs all the internal states of the encoder as well as the final representation associated with the first token (usually CLS token). Input: - tokens: B x T matrix representing sentences - segment_labels: B x T matrix representing segment label for tokens Output: - a tuple of the following: - a list of internal model states used to compute the predictions where each tensor has shape T x B x C - sentence representation associated with first input token in format B x C. """ def __init__( self, padding_idx: int, vocab_size: int, num_encoder_layers: int = 6, embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, layerdrop: float = 0.0, max_seq_len: int = 256, num_segments: int = 2, use_position_embeddings: bool = True, offset_positions_by_padding: bool = True, encoder_normalize_before: bool = False, apply_bert_init: bool = False, activation_fn: str = "relu", learned_pos_embedding: bool = True, embed_scale: float = None, freeze_embeddings: bool = False, n_trans_layers_to_freeze: int = 0, export: bool = False, traceable: bool = False, q_noise: float = 0.0, qn_block_size: int = 8, ) -> None: super().__init__() self.padding_idx = padding_idx self.vocab_size = vocab_size self.dropout_module = FairseqDropout( dropout, module_name=self.__class__.__name__) self.layerdrop = layerdrop self.max_seq_len = max_seq_len self.embedding_dim = embedding_dim self.num_segments = num_segments self.use_position_embeddings = use_position_embeddings self.apply_bert_init = apply_bert_init self.learned_pos_embedding = learned_pos_embedding self.traceable = traceable self.embed_tokens = self.build_embedding(self.vocab_size, self.embedding_dim, self.padding_idx) self.embed_scale = embed_scale if q_noise > 0: self.quant_noise = apply_quant_noise_( nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), q_noise, qn_block_size, ) else: self.quant_noise = None self.segment_embeddings = (nn.Embedding( self.num_segments, self.embedding_dim, padding_idx=None) if self.num_segments > 0 else None) self.embed_positions = (PositionalEmbedding( self.max_seq_len, self.embedding_dim, padding_idx=( self.padding_idx if offset_positions_by_padding else None), learned=self.learned_pos_embedding, ) if self.use_position_embeddings else None) if encoder_normalize_before: self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export) else: self.emb_layer_norm = None if self.layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_transformer_sentence_encoder_layer( embedding_dim=self.embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=self.dropout_module.p, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, export=export, q_noise=q_noise, qn_block_size=qn_block_size, ) for _ in range(num_encoder_layers) ]) # Apply initialization of model params after building the model if self.apply_bert_init: self.apply(init_bert_params) def freeze_module_params(m): if m is 
not None: for p in m.parameters(): p.requires_grad = False if freeze_embeddings: freeze_module_params(self.embed_tokens) freeze_module_params(self.segment_embeddings) freeze_module_params(self.embed_positions) freeze_module_params(self.emb_layer_norm) for layer in range(n_trans_layers_to_freeze): freeze_module_params(self.layers[layer]) def build_embedding(self, vocab_size, embedding_dim, padding_idx): return nn.Embedding(vocab_size, embedding_dim, padding_idx) def build_transformer_sentence_encoder_layer( self, embedding_dim, ffn_embedding_dim, num_attention_heads, dropout, attention_dropout, activation_dropout, activation_fn, export, q_noise, qn_block_size, ): return TransformerSentenceEncoderLayer( embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, export=export, q_noise=q_noise, qn_block_size=qn_block_size, ) def forward( self, tokens: torch.Tensor, segment_labels: torch.Tensor = None, last_state_only: bool = False, positions: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: is_tpu = tokens.device.type == "xla" # compute padding mask. This is needed for multi-head attention padding_mask = tokens.eq(self.padding_idx) if not self.traceable and not is_tpu and not padding_mask.any(): padding_mask = None if token_embeddings is not None: x = token_embeddings else: x = self.embed_tokens(tokens) if self.embed_scale is not None: x = x * self.embed_scale if self.embed_positions is not None: x = x + self.embed_positions(tokens, positions=positions) if self.segment_embeddings is not None and segment_labels is not None: x = x + self.segment_embeddings(segment_labels) if self.quant_noise is not None: x = self.quant_noise(x) if self.emb_layer_norm is not None: x = self.emb_layer_norm(x) x = self.dropout_module(x) # account for padding while computing the representation if padding_mask is not None: x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C x = x.transpose(0, 1) inner_states = [] if not last_state_only: inner_states.append(x) for layer in self.layers: x, _ = layer(x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask) if not last_state_only: inner_states.append(x) sentence_rep = x[0, :, :] if last_state_only: inner_states = [x] if self.traceable: return torch.stack(inner_states), sentence_rep else: return inner_states, sentence_rep
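# Hedged usage sketch for the sentence encoder above (assumes its fairseq
# dependencies are importable; all sizes are illustrative only).
import torch

encoder = TransformerSentenceEncoder(
    padding_idx=1,
    vocab_size=1000,
    num_encoder_layers=2,
    embedding_dim=64,
    ffn_embedding_dim=128,
    num_attention_heads=4,
    max_seq_len=32,
)
tokens = torch.randint(2, 1000, (3, 10))     # B x T, no padding in this example
inner_states, sentence_rep = encoder(tokens)
# inner_states: list of T x B x C tensors; sentence_rep: B x C (first token)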
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout = args.dropout self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) self.embed_positions = ( PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None ) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_encoder_layer(args, i) for i in range(args.encoder_layers) ]) self.num_layers = len(self.layers) self.attention_window = args.attention_window if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None def build_encoder_layer(self, args, layer_id): return LongformerEncoderLayer(args, layer_id) def padding_src_tokens(self, src_tokens, _w): _device = src_tokens.device src_seqlen = src_tokens.size()[1] padded_src_seqlen = (int(src_seqlen//_w) + 1) * _w - src_seqlen padded_tensor = torch.zeros([src_tokens.size()[0],padded_src_seqlen], device = _device) + self.padding_idx res = torch.cat([src_tokens, padded_tensor.long()], axis = 1) return res def find_split_of_source_and_context(self, src_tokens): #looking for split index of the source in the input instance #there is only one <s> as the start (src_tokens_0), #and the </s> after that <s> is the end B=src_tokens.size()[0] src_tokens_start = torch.where(src_tokens==0)[1] src_tokens_end_tuple = torch.where(src_tokens==2) src_tokens_end_sen_index = src_tokens_end_tuple[0] src_tokens_end_index = src_tokens_end_tuple[1] src_tokens_end = torch.zeros(src_tokens_start.size()) for i in range(B): src_tokens_end_cur = src_tokens_end_index[torch.where(src_tokens_end_sen_index==i)] src_tokens_end[i] = src_tokens_end_cur[torch.where(src_tokens_end_cur > src_tokens_start[i])[0][0]] return src_tokens_start, src_tokens_end.int(), src_tokens_end_tuple def build_source_sentence_mask(self, src_tokens, src_tokens_start, src_tokens_end): _device = src_tokens.device B=src_tokens.size()[0] encoder_padding_mask = torch.zeros(src_tokens.size(), device = _device) for i in range(B): encoder_padding_mask[i, (1+src_tokens_start[i]): (1+src_tokens_end[i])] = 1 encoder_padding_mask = encoder_padding_mask < 0.5 return encoder_padding_mask def forward_embedding(self, src_tokens): # embed tokens and positions x = embed = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if 
self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = F.dropout(x, p=self.dropout, training=self.training) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def forward( self, src_tokens, src_lengths, return_all_hiddens: bool = False, ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ ## first find the start and end of the src sentence src_tokens_start, src_tokens_end, _ = self.find_split_of_source_and_context(src_tokens) seqlen = src_tokens.size()[1] max_window_size = max(self.attention_window) ##padding the input x with seqlen of 2*w's multiple, ##e.g. if w=32, let the seqlen to be 64's multiple. src_tokens = self.padding_src_tokens(src_tokens, 2*max_window_size) x, encoder_embedding = self.forward_embedding(src_tokens) # B x T x C -> T x B x C #x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) encoder_states = [] if return_all_hiddens else None encoder_attn = [] if return_all_hiddens else None # encoder layers for i, layer in enumerate(self.layers): x, _ = layer(x, output_attentions = True) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) assert encoder_attn is not None encoder_attn.append(_) ##only get the unpadded parts... x = x[:, 0:seqlen, :] encoder_padding_mask = self.build_source_sentence_mask(src_tokens[:, 0:seqlen], src_tokens_start, src_tokens_end) x = x.transpose(0, 1) if self.layer_norm is not None: x = self.layer_norm(x) return EncoderOut( encoder_out=x, # T x B x C encoder_padding_mask=encoder_padding_mask, # B x T encoder_embedding=encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=None, src_lengths=None, encoder_attn=encoder_attn ) @torch.jit.export def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ new_encoder_out: Dict[str, Tensor] = {} new_encoder_out["encoder_out"] = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order) ) new_encoder_out["encoder_padding_mask"] = ( encoder_out.encoder_padding_mask if encoder_out.encoder_padding_mask is None else encoder_out.encoder_padding_mask.index_select(0, new_order) ) new_encoder_out["encoder_embedding"] = ( encoder_out.encoder_embedding if encoder_out.encoder_embedding is None else encoder_out.encoder_embedding.index_select(0, new_order) ) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( encoder_out=new_encoder_out["encoder_out"], # T x B x C encoder_padding_mask=new_encoder_out["encoder_padding_mask"], # B x T encoder_embedding=new_encoder_out["encoder_embedding"], # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 encoder_attn = [] ) def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i) ) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
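# Hedged sketch of the window padding done in forward() above: the
# Longformer-style layers expect the sequence length to be a multiple of
# 2 * attention_window, so <pad> tokens are appended on the right and the padded
# tail is cut off again after the encoder layers. The rounding below is the usual
# "round up to a multiple" variant (padding_src_tokens above always appends at
# least one extra window).
import torch

padding_idx = 1
window = 4
src_tokens = torch.tensor([[5, 6, 7, 8, 9, 10]])  # B x T with T = 6
multiple = 2 * window
target_len = ((src_tokens.size(1) + multiple - 1) // multiple) * multiple
pad = torch.full((src_tokens.size(0), target_len - src_tokens.size(1)),
                 padding_idx, dtype=src_tokens.dtype)
padded = torch.cat([src_tokens, pad], dim=1)      # B x 8
assert padded.size(1) % multiple == 0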
class TransformerDecoder(FairseqDecoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self._future_mask = torch.empty(0) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__) self.decoder_layerdrop = args.decoder_layerdrop self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim self.embed_dim = embed_dim self.output_embed_dim = args.decoder_output_dim self.padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( embed_dim) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None self.project_in_dim = (Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None) self.embed_positions = (PositionalEmbedding( self.max_target_positions, embed_dim, self.padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None self.cross_self_attention = getattr(args, "cross_self_attention", False) if self.decoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_decoder_layer(args, no_encoder_attn) for _ in range(args.decoder_layers) ]) self.num_layers = len(self.layers) if args.decoder_normalize_before and not getattr( args, "no_decoder_final_norm", False): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.project_out_dim = (Linear( embed_dim, self.output_embed_dim, bias=False) if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None) self.adaptive_softmax = None self.output_projection = None if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif self.share_input_output_embed: self.output_projection = nn.Linear( self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0], bias=False, ) self.output_projection.weight = self.embed_tokens.weight else: self.output_projection = nn.Linear(self.output_embed_dim, len(dictionary), bias=False) nn.init.normal_(self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5) def build_decoder_layer(self, args, no_encoder_attn=False): layer = DecoderLayer(args.encoder_embed_dim, args.encoder_ffn_embed_dim, args.encoder_attention_heads, args.kdim, args.vdim, args.dropout) if 
getattr(args, "checkpoint_activations", False): offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) return layer def forward( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features( prev_output_tokens, # 64 x 30 encoder_out= encoder_out, # từ điển chứa : encoder_out, encoder_padding_mask encoder_state # incremental_state=incremental_state, # full_context_alignment=full_context_alignment, # alignment_layer=alignment_layer, # alignment_heads=alignment_heads, ) return x, extra def extract_features( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]], # incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, # full_context_alignment: bool = False, # # alignment_layer: Optional[int] = None, # alignment_heads: Optional[int] = None, ): return self.extract_features_scriptable( prev_output_tokens, False, encoder_out, # incremental_state, # full_context_alignment, # # alignment_layer, # alignment_heads, ) """ A scriptable subclass of this class has an extract_features method and calls super().extract_features, but super() is not supported in torchscript. A copy of this function is made to be used in the subclass instead. """ def extract_features_scriptable( self, prev_output_tokens, return_attns, encoder_out: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, ): """ Similar to *forward* but only return features. Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al., EMNLP 2019). Args: full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). alignment_layer (int, optional): return mean alignment over heads at this layer (default: last layer). alignment_heads (int, optional): only average alignment over this many heads (default: all heads). 
Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ dec_slf_attn_list, dec_enc_attn_list = [], [] # -- Prepare masks non_pad_mask = get_non_pad_mask(prev_output_tokens, self.padding_idx) slf_attn_mask_subseq = get_subsequent_mask(prev_output_tokens) slf_attn_mask_keypad = get_attn_key_pad_mask( seq_k=prev_output_tokens, seq_q=prev_output_tokens, padding_idx=self.padding_idx) slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) dec_enc_attn_mask = get_attn_key_pad_mask( seq_k=encoder_out['src_tokens'][0], seq_q=prev_output_tokens, padding_idx=self.padding_idx) dec_output_phase = self.embed_positions(prev_output_tokens) # embed tokens and positions dec_output_real_real = self.embed_tokens(prev_output_tokens) tgt_pos = get_position(prev_output_tokens) pos = torch.unsqueeze(torch.as_tensor(tgt_pos, dtype=torch.long, device=prev_output_tokens.device), 2) # device-agnostic (previously hard-coded to CUDA) dec_output_phase = torch.mul(pos.float(), dec_output_phase) # size = 64 x 51 x 512 cos = torch.cos(dec_output_phase) sin = torch.sin(dec_output_phase) # f(j, pos) = f_we(j) (.) f_pe(j, pos) --> element-wise multiplication dec_output_real = dec_output_real_real * cos dec_output_phase = dec_output_real_real * sin # decoder layers for dec_layer in self.layers: dec_output_real, dec_output_phase, dec_slf_attn, dec_enc_attn = dec_layer( dec_output_real, dec_output_phase, encoder_out["enc_output_real"][0].permute(1, 0, 2), encoder_out['enc_output_phase'][0].permute(1, 0, 2), non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask) if return_attns: dec_slf_attn_list += [dec_slf_attn] dec_enc_attn_list += [dec_enc_attn] dec_output = dec_output_real * dec_output_real + dec_output_phase * dec_output_phase # size = (64, 51, 512) seq_logit = self.output_projection(dec_output) return seq_logit, None def output_layer(self, features): """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary return self.output_projection(features) else: return features def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions) def buffered_future_mask(self, tensor): dim = tensor.size(0) # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
if (self._future_mask.size(0) == 0 or (not self._future_mask.device == tensor.device) or self._future_mask.size(0) < dim): self._future_mask = torch.triu( utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1) self._future_mask = self._future_mask.to(tensor) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict["{}.embed_positions._float_tensor".format( name)] = torch.FloatTensor(1) if f"{name}.output_projection.weight" not in state_dict: if self.share_input_output_embed: embed_out_key = f"{name}.embed_tokens.weight" else: embed_out_key = f"{name}.embed_out" if embed_out_key in state_dict: state_dict[f"{name}.output_projection.weight"] = state_dict[ embed_out_key] if not self.share_input_output_embed: del state_dict[embed_out_key] for i in range(self.num_layers): # update layer norms layer_norm_map = { "0": "self_attn_layer_norm", "1": "encoder_attn_layer_norm", "2": "final_layer_norm", } for old, new in layer_norm_map.items(): for m in ("weight", "bias"): k = "{}.layers.{}.layer_norms.{}.{}".format( name, i, old, m) if k in state_dict: state_dict["{}.layers.{}.{}.{}".format( name, i, new, m)] = state_dict[k] del state_dict[k] version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
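# Hedged sketch of the real/phase coupling used in extract_features_scriptable
# above: the token embedding a is rotated by a position-dependent phase theta,
# i.e. treated as the complex number a * (cos(theta) + i*sin(theta)); the two
# streams are propagated through the decoder layers and the final feature is the
# squared magnitude real**2 + imag**2. Values below are illustrative.
import torch

a = torch.randn(2, 5, 8)        # token embedding (amplitude), B x T x C
theta = torch.randn(2, 5, 8)    # position-dependent phase
real = a * torch.cos(theta)
imag = a * torch.sin(theta)
magnitude_sq = real * real + imag * imag
assert torch.allclose(magnitude_sq, a * a, atol=1e-6)  # holds before any layer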
class SpeechTransformerDecoderBase(TransformerDecoderBase): def __init__( self, cfg, dictionary, embed_tokens, no_encoder_attn=False, output_projection=None, scheduled_sampling_rate_scheduler=None, ): is_no_token_positional_embeddings_changed = False if (not cfg.no_token_positional_embeddings and cfg.decoder.relative_positional_embeddings): cfg.no_token_positional_embeddings = True is_no_token_positional_embeddings_changed = True logger.info( "disabled decoder's absolute positional embeddings as decoder_relative_positional_embeddings is True." ) self.cfg = cfg super(TransformerDecoderBase, self).__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self._future_mask = torch.empty(0) self.dropout_module = FairseqDropout( cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)) self.decoder_layerdrop = cfg.decoder.layerdrop self.share_input_output_embed = cfg.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = cfg.decoder.embed_dim self.embed_dim = embed_dim self.output_embed_dim = cfg.decoder.output_dim self.padding_idx = embed_tokens.padding_idx self.max_target_positions = cfg.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt( embed_dim) if not cfg.adaptive_input and cfg.quant_noise.pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), cfg.quant_noise.pq, cfg.quant_noise.pq_block_size, ) else: self.quant_noise = None self.project_in_dim = (Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None) self.embed_positions = (PositionalEmbedding( self.max_target_positions, embed_dim, self.padding_idx, learned=cfg.decoder.learned_pos, ) if not cfg.no_token_positional_embeddings else None) if cfg.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) else: self.layernorm_embedding = None self.cross_self_attention = cfg.cross_self_attention if cfg.decoder.relative_positional_embeddings: if cfg.decoder.learned_pos: rel_pos_embed_list = [ RelativePositionalEmbedding( cfg.decoder.embed_dim, padding_idx=None, max_size=cfg.max_target_positions, learned=True, ) for _ in range(cfg.decoder.layers) ] else: rel_pos_embed = RelativePositionalEmbedding( cfg.decoder.embed_dim, padding_idx=None, max_size=None, learned=False, ) # single instance referenced across layers rel_pos_embed_list = [rel_pos_embed] * cfg.decoder.layers else: rel_pos_embed_list = [None] * cfg.decoder.layers if self.decoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_decoder_layer( cfg, no_encoder_attn, positional_embedding=rel_pos_embed_list[i]) for i in range(cfg.decoder.layers) ]) self.num_layers = len(self.layers) if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm: self.layer_norm = LayerNorm(embed_dim, export=cfg.export) else: self.layer_norm = None self.project_out_dim = (Linear( embed_dim, self.output_embed_dim, bias=False) if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights else None) self.adaptive_softmax = None self.output_projection = output_projection if self.output_projection is None: self.build_output_projection(cfg, dictionary, embed_tokens) if is_no_token_positional_embeddings_changed: cfg.no_token_positional_embeddings = not cfg.no_token_positional_embeddings self.scheduled_sampling_rate_scheduler = scheduled_sampling_rate_scheduler for layer in 
self.layers: if isinstance( layer, TransformerWithRelativePositionalEmbeddingDecoderLayerBase ): layer.need_attn = False # make validation fast def build_decoder_layer( self, cfg, no_encoder_attn=False, positional_embedding: Optional[RelativePositionalEmbedding] = None, ): layer = TransformerWithRelativePositionalEmbeddingDecoderLayerBase( cfg, no_encoder_attn=no_encoder_attn, positional_embedding=positional_embedding, ) checkpoint = cfg.checkpoint_activations if checkpoint: offload_to_cpu = cfg.offload_activations layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, **kwargs, ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing encoder_out (optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ if (self.training and alignment_layer is None): # no attention tensors during training to save memory alignment_layer = self.num_layers # can be any value no less than this if self.training and self.scheduled_sampling_rate_scheduler is not None: epoch = kwargs.get("epoch", 1) sampling_prob = self.scheduled_sampling_rate_scheduler.step(epoch) if sampling_prob < 1.0: # apply scheduled sampling assert not features_only return self._forward_with_scheduled_sampling( prev_output_tokens, sampling_prob, encoder_out=encoder_out, incremental_state= {}, # use empty dict to preserve forward state full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens, ) x, extra = self.extract_features( prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, ) if not features_only: x = self.output_layer(x) return x, extra def _forward_with_scheduled_sampling( self, prev_output_tokens, sampling_prob, encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, ): bsz, seqlen = prev_output_tokens.size() outs = [] pred = None for step in range(seqlen): if step > 0: sampling_mask = torch.rand( [bsz, 1], device=prev_output_tokens.device, 
).lt(sampling_prob) feed_tokens = torch.where( sampling_mask, prev_output_tokens[:, step:step + 1], pred, ) else: feed_tokens = prev_output_tokens[:, step:step + 1] # B x 1 x, _ = self.extract_features( feed_tokens, encoder_out=encoder_out, incremental_state=incremental_state, full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, ) x = self.output_layer(x) # B x 1 x V outs.append(x) pred = x.argmax(-1) # B x 1 x = torch.cat(outs, dim=1) # B x T x V return x, None def initialize_cached_state( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, ): bsz = prev_output_tokens.size(0) x = self.embed_tokens(prev_output_tokens) num_heads = self.cfg.decoder.attention_heads embed_dim = self.cfg.decoder.embed_dim zero_states = x.new_zeros(len(self.layers), bsz, num_heads, 0, embed_dim // num_heads) cache_state = torch.jit.annotate( Dict[str, Optional[Tensor]], { "prev_key": zero_states, "prev_value": zero_states, }, ) self.set_incremental_state(incremental_state, "cached_state", cache_state) def get_cached_state( self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], ) -> Tuple[List[Tensor], List[Tensor], Optional[List[Tensor]]]: prev_key, prev_value, prev_key_padding_mask = [], [], [] for layer in self.layers: attn_state = layer.self_attn.get_incremental_state( incremental_state, "attn_state") assert attn_state is not None assert attn_state["prev_key"] is not None prev_key.append(attn_state["prev_key"]) assert attn_state["prev_value"] is not None prev_value.append(attn_state["prev_value"]) if len(attn_state ) >= 3 and attn_state["prev_key_padding_mask"] is not None: prev_key_padding_mask.append( attn_state["prev_key_padding_mask"]) if len(prev_key_padding_mask) == 0: prev_key_padding_mask = None else: assert len(prev_key_padding_mask) == len(prev_key) return prev_key, prev_value, prev_key_padding_mask def get_incremental_state( self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], key: str, ) -> Optional[Dict[str, Optional[Tensor]]]: if key != "cached_state": return super().get_incremental_state(incremental_state, key) if incremental_state is None: return None prev_key, prev_value, prev_key_padding_mask = self.get_cached_state( self, incremental_state) cached_state: Dict[str, Optional[Tensor]] = { "prev_key": torch.stack(prev_key), "prev_value": torch.stack(prev_value), } if prev_key_padding_mask is not None: cached_state["prev_key_padding_mask"] = torch.stack( prev_key_padding_mask) return cached_state def set_incremental_state( self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], key: str, value: Dict[str, Optional[Tensor]], ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: if key != "cached_state": return super().set_incremental_state(incremental_state, key, value) if incremental_state is not None: pre_key = value["pre_key"] pre_value = value["pre_value"] for i, layer in enumerate(self.layers): attn_state: Dict[str, Optional[Tensor]] = { "prev_key": pre_key[i] if pre_key is not None else None, "prev_value": pre_value[i] if pre_value is not None else None, } if len(value) >= 3: prev_key_padding_mask = value["prev_key_padding_mask"] attn_state["prev_key_padding_mask"] = ( prev_key_padding_mask[i] if prev_key_padding_mask is not None else None) layer.self_attn.set_incremental_state(incremental_state, "attn_state", attn_state) return incremental_state def masked_copy_cached_state( 
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], src_cached_state: Tuple[Optional[Union[List[torch.Tensor], torch.Tensor]]], mask: Tensor, ): if incremental_state is None: assert src_cached_state is None or len(src_cached_state) == 0 return prev_key, prev_value, prev_key_padding_mask = self.get_cached_state( incremental_state) src_prev_key, src_prev_value, src_prev_key_padding_mask = ( src_cached_state[0], src_cached_state[1], src_cached_state[2], ) # pad one more time step, assuming src_cached_state is just one step behind src_prev_key = [F.pad(src_p, (0, 0, 0, 1)) for src_p in src_prev_key] src_prev_value = [ F.pad(src_p, (0, 0, 0, 1)) for src_p in src_prev_value ] if src_prev_key_padding_mask is not None: src_prev_key_padding_mask = [ F.pad(src_p, (0, 1)) for src_p in src_prev_key_padding_mask ] def masked_copy_state(state: Optional[Tensor], src_state: Optional[Tensor]): if state is None: assert src_state is None return None else: assert (state.size(0) == mask.size(0) and src_state is not None and state.size() == src_state.size()) state[mask, ...] = src_state[mask, ...] return state prev_key = [ masked_copy_state(p, src_p) for (p, src_p) in zip(prev_key, src_prev_key) ] prev_value = [ masked_copy_state(p, src_p) for (p, src_p) in zip(prev_value, src_prev_value) ] if prev_key_padding_mask is None: prev_key_padding_mask = src_prev_key_padding_mask else: assert src_prev_key_padding_mask is not None prev_key_padding_mask = [ masked_copy_state(p, src_p) for (p, src_p ) in zip(prev_key_padding_mask, src_prev_key_padding_mask) ] cached_state = torch.jit.annotate( Dict[str, Optional[Tensor]], { "prev_key": torch.stack(prev_key), "prev_value": torch.stack(prev_value), }, ) if prev_key_padding_mask is not None: cached_state["prev_key_padding_mask"] = torch.stack( prev_key_padding_mask) self.set_incremental_state(incremental_state, "cached_state", cached_state)
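

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the decoder above; the `demo_*` name is
# hypothetical).  `_forward_with_scheduled_sampling` mixes gold tokens with the
# model's own predictions one step at a time using a (bsz, 1) Bernoulli mask.
# The helper below applies the same mixing rule to a whole batch of positions
# at once, which is easier to inspect in isolation.
# ---------------------------------------------------------------------------
def demo_scheduled_sampling_mix(gold_tokens, predicted_tokens, sampling_prob):
    """Keep each gold token with probability `sampling_prob`; otherwise feed
    the model's own prediction.  Both inputs are LongTensors of shape B x T."""
    import torch

    keep_gold = torch.rand(gold_tokens.shape,
                           device=gold_tokens.device).lt(sampling_prob)
    return torch.where(keep_gold, gold_tokens, predicted_tokens)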
class LunaSentenceEncoder(nn.Module): """ Implementation for a Bi-directional Luna based Sentence Encoder used in masked pre-trained language models. This first computes the token embedding using the token embedding matrix, position embeddings (if specified) and segment embeddings (if specified). After applying the specified number of TransformerEncoderLayers, it outputs all the internal states of the encoder as well as the final representation associated with the first token (usually CLS token). Input: - tokens: B x T matrix representing sentences - segment_labels: B x T matrix representing segment label for tokens Output: - a tuple of the following: - a list of internal model states used to compute the predictions where each tensor has shape T x B x C - sentence representation associated with first input token in format B x C. """ def __init__( self, padding_idx: int, vocab_size: int, projection_length: int = 128, num_encoder_layers: int = 12, embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 12, num_projected_attention_heads: int = 12, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, layerdrop: float = 0.0, max_seq_len: int = 512, num_segments: int = 0, use_position_embeddings: bool = True, offset_positions_by_padding: bool = True, layernorm_embedding: bool = False, normalize_before: bool = False, dynamic_projection: bool = True, tie_kv=False, apply_bert_init: bool = False, activation_fn: str = "gelu", learned_pos_embedding: bool = True, embed_scale: float = None, freeze_embeddings: bool = False, n_trans_layers_to_freeze: int = 0, export: bool = False, traceable: bool = False, ) -> None: super().__init__() self.padding_idx = padding_idx self.vocab_size = vocab_size self.proj_len = projection_length self.dynamic_projection = dynamic_projection self.dropout_module = FairseqDropout( dropout, module_name=self.__class__.__name__) self.layerdrop = layerdrop self.max_seq_len = max_seq_len self.embedding_dim = embedding_dim self.num_segments = num_segments self.use_position_embeddings = use_position_embeddings self.apply_bert_init = apply_bert_init self.learned_pos_embedding = learned_pos_embedding self.traceable = traceable self.tpu = False # whether we're on TPU self.embed_tokens = self.build_embedding(self.vocab_size, self.embedding_dim, self.padding_idx) self.embed_scale = embed_scale if self.num_segments > 0: self.segment_embeddings = nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None) nn.init.normal_(self.segment_embeddings.weight, mean=0.0, std=self.embedding_dim**-0.5) else: self.segment_embeddings = None self.embed_positions = (PositionalEmbedding( self.max_seq_len, self.embedding_dim, padding_idx=( self.padding_idx if offset_positions_by_padding else None), learned=self.learned_pos_embedding, ) if self.use_position_embeddings else None) self.projected_embeddings = Parameter( torch.Tensor(self.proj_len, self.embedding_dim)) nn.init.normal_(self.projected_embeddings, mean=0.0, std=self.embedding_dim**-0.5) if self.use_position_embeddings and not self.learned_pos_embedding: projected_positions = get_sinusoidal_positional_embedding( self.proj_len, self.embedding_dim) if self.embed_scale is None: self.embed_scale = math.sqrt(self.embedding_dim) else: projected_positions = None self.register_buffer("projected_positions", projected_positions) if self.layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ 
            self.build_luna_sentence_encoder_layer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=num_attention_heads,
                num_projected_attention_heads=num_projected_attention_heads,
                dropout=self.dropout_module.p,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                normalize_before=normalize_before,
                tie_kv=tie_kv,
                export=export,
            ) for _ in range(num_encoder_layers)
        ])

        assert not layernorm_embedding or not normalize_before

        if layernorm_embedding:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
            self.proj_emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None
            self.proj_emb_layer_norm = None

        if normalize_before:
            self.layer_norm = LayerNorm(self.embedding_dim, export=export)
            self.proj_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.layer_norm = None
            self.proj_layer_norm = None

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings:
            self.projected_embeddings.requires_grad = False
            freeze_module_params(self.embed_tokens)
            freeze_module_params(self.segment_embeddings)
            freeze_module_params(self.embed_positions)
            freeze_module_params(self.emb_layer_norm)
            freeze_module_params(self.proj_emb_layer_norm)

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])

        log_class_usage(__class__)

    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        embed_tokens = nn.Embedding(vocab_size, embedding_dim, padding_idx)
        nn.init.normal_(embed_tokens.weight, mean=0, std=embedding_dim**-0.5)
        return embed_tokens

    def build_luna_sentence_encoder_layer(
        self,
        embedding_dim,
        ffn_embedding_dim,
        num_attention_heads,
        num_projected_attention_heads,
        dropout,
        attention_dropout,
        activation_dropout,
        activation_fn,
        normalize_before,
        tie_kv,
        export,
        # defaults added so the call in __init__ (which does not pass
        # quant-noise arguments) stays valid; override to enable quant noise
        q_noise=0.0,
        qn_block_size=8,
    ):
        return LunaSentenceEncoderLayer(
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=num_attention_heads,
            num_projected_attention_heads=num_projected_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            activation_fn=activation_fn,
            normalize_before=normalize_before,
            tie_kv=tie_kv,
            export=export,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )

    def prepare_for_tpu_(self, **kwargs):
        self.tpu = True

    def forward(
        self,
        tokens: torch.Tensor,
        segment_labels: torch.Tensor = None,
        last_state_only: bool = False,
        positions: Optional[torch.Tensor] = None,
    ):
        # compute padding mask.
        # This is needed for multi-head attention
        # B x T
        x_padding_mask = tokens.eq(self.padding_idx)
        lengths = tokens.size(1) - x_padding_mask.sum(1)
        max_len = lengths.max() if self.dynamic_projection else self.proj_len

        x = self.embed_tokens(tokens)
        px = self.projected_embeddings[:max_len]

        if self.embed_scale is not None:
            x *= self.embed_scale
            # avoid in-place ops on `px`, which is a view of a Parameter
            px = px * self.embed_scale

        if self.embed_positions is not None:
            x += self.embed_positions(tokens, positions=positions)

        if self.projected_positions is not None:
            px = px + self.projected_positions[:max_len]

        if self.segment_embeddings is not None and segment_labels is not None:
            x += self.segment_embeddings(segment_labels)

        # `quant_noise` is never constructed in __init__, so guard the lookup
        if getattr(self, "quant_noise", None) is not None:
            x = self.quant_noise(x)

        if self.emb_layer_norm is not None:
            x = self.emb_layer_norm(x)
            px = self.proj_emb_layer_norm(px)

        bsz = x.size(0)
        plen, dim = px.size()
        # L x C -> B x L x C
        px = px.unsqueeze(0).expand(bsz, plen, dim)

        if self.dynamic_projection:
            pidx = torch.arange(plen).unsqueeze(0).to(x.device)
            # B x L
            px_padding_mask = pidx.ge(lengths.unsqueeze(1))
        else:
            px_padding_mask = None

        if not self.traceable and not self.tpu:
            if not x_padding_mask.any():
                x_padding_mask = None
            if px_padding_mask is not None and not px_padding_mask.any():
                px_padding_mask = None

        x = self.dropout_module(x)
        px = self.dropout_module(px)

        # account for padding while computing the representation
        if x_padding_mask is not None:
            x = x * (1 - x_padding_mask.unsqueeze(-1).type_as(x))
        if px_padding_mask is not None:
            px = px * (1 - px_padding_mask.unsqueeze(-1).type_as(px))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        # B x L x C -> L x B x C
        px = px.transpose(0, 1)

        inner_states = []
        if not last_state_only:
            inner_states.append(x)

        for layer in self.layers:
            x, px, _ = layer(x, px,
                             x_padding_mask=x_padding_mask,
                             px_padding_mask=px_padding_mask)
            if not last_state_only:
                inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)
            px = self.proj_layer_norm(px)

        # sentence_cls_rep = x[0, :, :]
        # sentence_proj_rep = px

        if last_state_only:
            inner_states = [x]

        return inner_states
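

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper; not used by LunaSentenceEncoder).
# The core Luna idea is a two-step "pack and unpack" attention: a projected
# sequence of fixed length L first attends over the T input tokens (pack,
# cost O(T*L)), and the input then attends over the packed sequence (unpack,
# also O(T*L)), avoiding the O(T^2) cost of full self-attention.  Single-head,
# unmasked version for clarity; the real layers add projections, heads,
# dropout and padding masks.
# ---------------------------------------------------------------------------
def demo_luna_attention(x, px):
    """x: (T, B, C) input sequence; px: (L, B, C) projected sequence."""
    import math
    import torch

    scale = 1.0 / math.sqrt(x.size(-1))
    # pack: projected queries attend over the input -> packed is (L, B, C)
    pack_scores = torch.einsum("lbc,tbc->blt", px, x) * scale
    packed = torch.einsum("blt,tbc->lbc", pack_scores.softmax(dim=-1), x)
    # unpack: input queries attend over the packed sequence -> out is (T, B, C)
    unpack_scores = torch.einsum("tbc,lbc->btl", x, packed) * scale
    out = torch.einsum("btl,lbc->tbc", unpack_scores.softmax(dim=-1), packed)
    return out, packed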
class LunaEncoder(FairseqEncoder): """ Luna encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`LunaEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) self.dropword_module = FairseqFeatureDropout(args.word_dropout, module_name=self.__class__.__name__) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim assert embed_dim == args.encoder_embed_dim, 'encoder embedding dim mismatch.' self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) self.embed_positions = ( PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None ) assert not args.layernorm_embedding or not args.encoder_normalize_before if args.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim) self.layernorm_porjected_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None self.layernorm_porjected_embedding = None self.proj_len = args.projection_length self.dynamic_projection = not args.fix_projection_length self.projected_embeddings = Parameter(torch.Tensor(self.proj_len, embed_dim)) nn.init.normal_(self.projected_embeddings, mean=0., std=embed_dim ** -0.5) if not args.no_token_positional_embeddings and not args.encoder_learned_pos: projected_positions = get_sinusoidal_positional_embedding(self.proj_len, embed_dim) else: projected_positions = None self.register_buffer('projected_positions', projected_positions) if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([self.build_encoder_layer(i, args) for i in range(args.encoder_layers)]) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) self.proj_layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.proj_layer_norm = None def build_encoder_layer(self, layer_id, args): return LunaEncoderLayer(args, layer_id) def forward_embedding(self, src_tokens): # embed tokens and positions x = embed = self.embed_scale * self.embed_tokens(src_tokens) x = self.dropword_module(x) if self.embed_positions is not None: x = x + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) return x, embed def forward_projected_embedding(self, src_lengths): max_len = src_lengths.max() if self.dynamic_projection else self.proj_len px = proj_embed = self.embed_scale * self.projected_embeddings[:max_len] if self.projected_positions is not None: px = px + self.projected_positions[:max_len] if self.layernorm_porjected_embedding is not None: px = self.layernorm_porjected_embedding(px) return px, proj_embed def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of 
the intermediate hidden states (default: False). Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ x, encoder_embedding = self.forward_embedding(src_tokens) px, projected_embedding = self.forward_projected_embedding(src_lengths) bsz = x.size(0) len, dim = px.size() # B x T x C -> T x B x C x = x.transpose(0, 1) # L x C -> L x B x C px = px.unsqueeze(1).expand(len, bsz, dim) x = self.dropout_module(x) px = self.dropout_module(px) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if self.dynamic_projection: pidx= torch.arange(len).unsqueeze(0).to(x.device) encoder_projected_padding_mask = pidx.ge(src_lengths.unsqueeze(1)) else: encoder_projected_padding_mask = None encoder_states = [] if return_all_hiddens else None # encoder layers for layer in self.layers: x, px = layer(x, px, encoder_padding_mask, encoder_projected_padding_mask) if return_all_hiddens: assert encoder_states is not None encoder_states.append((x, px)) if self.layer_norm is not None: x = self.layer_norm(x) px = self.proj_layer_norm(px) return EncoderOut( encoder_out=x, # T x B x C encoder_projected_out=px, # L x B x C encoder_padding_mask=encoder_padding_mask, # B x T encoder_projected_padding_mask=encoder_projected_padding_mask, # B x L encoder_embedding=encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=None, src_lengths=None, ) @torch.jit.export def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ """ Since encoder_padding_mask and encoder_embedding are both of type Optional[Tensor] in EncoderOut, they need to be copied as local variables for Torchscript Optional refinement """ encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask encoder_projected_padding_mask: Optional[Tensor] = encoder_out.encoder_projected_padding_mask encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding new_encoder_out = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order) ) new_encoder_projected_out = ( encoder_out.encoder_projected_out if encoder_out.encoder_projected_out is None else encoder_out.encoder_projected_out.index_select(1, new_order) ) new_encoder_padding_mask = ( encoder_padding_mask if encoder_padding_mask is None else encoder_padding_mask.index_select(0, new_order) ) new_encoder_projected_padding_mask = ( encoder_projected_padding_mask if encoder_projected_padding_mask is None else encoder_projected_padding_mask.index_select(0, new_order) ) new_encoder_embedding = ( encoder_embedding if encoder_embedding is None else encoder_embedding.index_select(0, new_order) ) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( encoder_out=new_encoder_out, # T x B x C encoder_projected_out=new_encoder_projected_out, # L x B x C encoder_padding_mask=new_encoder_padding_mask, # B x T encoder_projected_padding_mask=new_encoder_projected_padding_mask, # B x L encoder_embedding=new_encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 ) def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i) ) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
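

# ---------------------------------------------------------------------------
# Illustrative usage sketch (hypothetical tensors).  `reorder_encoder_out` is
# called by beam search whenever hypotheses are re-ranked, so every cached
# encoder tensor must be gathered along its batch dimension with `new_order`:
# time-major tensors (T x B x C) on dim=1, batch-major masks (B x T) on dim=0,
# exactly as in the method above.
# ---------------------------------------------------------------------------
def demo_reorder(encoder_out_tbc, padding_mask_bt, new_order):
    reordered_out = encoder_out_tbc.index_select(1, new_order)   # T x B' x C
    reordered_mask = padding_mask_bt.index_select(0, new_order)  # B' x T
    return reordered_out, reordered_mask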
class SingleShotTransformerDecoder(TransformerDecoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.args = args super().__init__(args, dictionary, embed_tokens) self.register_buffer("version", torch.Tensor([3])) self._future_mask = torch.empty(0) self.dropout = args.dropout self.decoder_layerdrop = args.decoder_layerdrop self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim self.embed_dim = embed_dim self.output_embed_dim = args.decoder_output_dim self.padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None self.project_in_dim = ( Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None ) self.embed_positions = ( PositionalEmbedding( args.max_target_positions, embed_dim, self.padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None ) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None self.cross_self_attention = getattr(args, "cross_self_attention", False) if self.decoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend([ self.build_decoder_layer(args, no_encoder_attn) for _ in range(args.decoder_layers) ]) self.num_layers = len(self.layers) if args.decoder_normalize_before and not getattr(args, "no_decoder_final_norm", False): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.project_out_dim = ( Linear(embed_dim, self.output_embed_dim, bias=False) if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None ) self.adaptive_softmax = None self.output_projection = None if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, options.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif self.share_input_output_embed: self.output_projection = nn.Linear( self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0], bias=False, ) self.output_projection.weight = self.embed_tokens.weight else: self.output_projection = nn.Linear( self.output_embed_dim, len(dictionary), bias=False ) nn.init.normal_( self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 ) def build_decoder_layer(self, args, no_encoder_attn=False): return SingleShotTransformerDecoderLayer(args, no_encoder_attn=no_encoder_attn) def forward( self, prev_output_tokens, encoder_out: Optional[EncoderOut] = 
None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, q_raws = self.extract_features( prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, alignment_layer=alignment_layer, alignment_heads=alignment_heads, ) if not features_only: x = self.output_layer(x) return x, q_raws def extract_features( self, prev_output_tokens, encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, ): """ Similar to *forward* but only return features. Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al., EMNLP 2019). Args: full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). alignment_layer (int, optional): return mean alignment over heads at this layer (default: last layer). alignment_heads (int, optional): only average alignment over this many heads (default: all heads). 
Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ if alignment_layer is None: alignment_layer = self.num_layers - 1 # embed positions positions = ( self.embed_positions( prev_output_tokens, incremental_state=incremental_state ) if self.embed_positions is not None else None ) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.quant_noise is not None: x = self.quant_noise(x) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) self_attn_padding_mask: Optional[Tensor] = None if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) # decoder layers attn: Optional[Tensor] = None inner_states: List[Optional[Tensor]] = [x] q_raws: List[Optional[Tensor]] = [] for idx, layer in enumerate(self.layers): # if incremental_state is None and not full_context_alignment: # self_attn_mask = self.buffered_future_mask(x) # else: self_attn_mask = None x, layer_attn, q_raw = layer( x, encoder_out.encoder_states[idx] if encoder_out is not None else None, encoder_out.encoder_padding_mask if encoder_out is not None else None, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_attn=bool((idx == alignment_layer)), need_head_weights=bool((idx == alignment_layer)), ) q_raws.append(q_raw) inner_states.append(x) if layer_attn is not None and idx == alignment_layer: attn = layer_attn.float().to(x) if attn is not None: if alignment_heads is not None: attn = attn[:alignment_heads] # average probabilities over heads attn = attn.mean(dim=0) if self.layer_norm is not None: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) return x, q_raws #{"attn": [attn], "inner_states": inner_states}
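

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper).  SingleShotTransformerDecoder is
# non-autoregressive: the causal future mask is deliberately disabled
# (`self_attn_mask` stays None), so every target position is predicted in one
# parallel pass.  For contrast, this is the triangular mask an autoregressive
# decoder would add to its self-attention logits.
# ---------------------------------------------------------------------------
def demo_future_mask(tgt_len, device=None):
    import torch

    # 0 on and below the diagonal, -inf strictly above it: position i may not
    # attend to positions j > i.
    mask = torch.full((tgt_len, tgt_len), float("-inf"), device=device)
    return torch.triu(mask, diagonal=1)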
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( embed_dim) self.embed_positions = ( PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None ) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), args.quant_noise_pq, args.quant_noise_pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend( [self.build_encoder_layer(args) for i in range(args.encoder_layers)] ) self.num_layers = len(self.layers) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.layer_wise_attention = args.layer_wise_attn # for T5 attention self.rel_pos = args.rel_pos self.relative_attention_num_buckets = args.relative_attention_num_buckets self.max_rel_dist = args.max_rel_pos self.ignore_direction = args.ignore_direction self.num_heads = args.encoder_attention_heads if self.rel_pos: max_dist = self.relative_attention_num_buckets + 1 \ if self.ignore_direction else self.relative_attention_num_buckets * 2 + 1 self.relative_attention_bias = nn.Embedding(max_dist, self.num_heads) def build_encoder_layer(self, args): return TransformerEncoderLayer(args) def forward_embedding(self, src_tokens): # embed tokens and positions x = embed = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. 
Only populated if *return_all_hiddens* is True. """ if self.layer_wise_attention: return_all_hiddens = True x, encoder_embedding = self.forward_embedding(src_tokens) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) encoder_states = [] if return_all_hiddens else None position_bias = None if self.rel_pos: seq_len, bsz, _ = x.size() position_bias = self.compute_bias(seq_len, seq_len) # (1, n_heads, qlen, klen) position_bias = position_bias.repeat(bsz, 1, 1, 1) position_bias = position_bias.view(bsz * self.num_heads, seq_len, seq_len) # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask, position_bias=position_bias) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) if self.layer_norm is not None: x = self.layer_norm(x) return EncoderOut( encoder_out=x, # T x B x C encoder_padding_mask=encoder_padding_mask, # B x T encoder_embedding=encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=None, src_lengths=None, ) @torch.jit.export def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ """ Since encoder_padding_mask and encoder_embedding are both of type Optional[Tensor] in EncoderOut, they need to be copied as local variables for Torchscript Optional refinement """ encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding new_encoder_out = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order) ) new_encoder_padding_mask = ( encoder_padding_mask if encoder_padding_mask is None else encoder_padding_mask.index_select(0, new_order) ) new_encoder_embedding = ( encoder_embedding if encoder_embedding is None else encoder_embedding.index_select(0, new_order) ) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( encoder_out=new_encoder_out, # T x B x C encoder_padding_mask=new_encoder_padding_mask, # B x T encoder_embedding=new_encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 ) def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = "{}.embed_positions.weights".format(name) if weights_key in state_dict: print("deleting {0}".format(weights_key)) del state_dict[weights_key] state_dict[ "{}.embed_positions._float_tensor".format(name) ] = torch.FloatTensor(1) for i in range(self.num_layers): # update layer norms self.layers[i].upgrade_state_dict_named( 
state_dict, "{}.layers.{}".format(name, i) ) version_key = "{}.version".format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict def compute_bias(self, qlen, klen): """ Compute binned relative position bias """ context_position = torch.arange(qlen, dtype=torch.long)[:, None] memory_position = torch.arange(klen, dtype=torch.long)[None, :] relative_position = memory_position - context_position # shape (qlen, klen) distance_mat_clipped = torch.clamp(relative_position, min=-self.relative_attention_num_buckets, max=self.relative_attention_num_buckets) if self.ignore_direction: rp_bucket = torch.abs(distance_mat_clipped) else: rp_bucket = distance_mat_clipped + self.relative_attention_num_buckets rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) return values
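

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper).  The (bsz * num_heads, qlen, klen)
# tensor returned by `compute_bias` is typically consumed by adding it to the
# raw attention logits before the softmax, so every query/key pair at the same
# relative offset shares one learned bias value per head.
# ---------------------------------------------------------------------------
def demo_apply_position_bias(attn_logits, position_bias):
    """attn_logits, position_bias: (bsz * num_heads, qlen, klen)."""
    return (attn_logits + position_bias).softmax(dim=-1)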