class CharCNNEncoder(FairseqEncoder):
    """
    Character-level CNN encoder to generate word representations, as input to
    a transformer encoder.
    """

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        num_chars=50,
        embed_dim=32,
        char_cnn_params="[(128, 3), (128, 5)]",
        char_cnn_nonlinear_fn="tanh",
        char_cnn_pool_type="max",
        char_cnn_num_highway_layers=0,
        char_cnn_output_dim=-1,
        use_pretrained_weights=False,
        finetune_pretrained_weights=False,
        weights_file=None,
    ):
        super().__init__(dictionary)

        convolutions_params = literal_eval(char_cnn_params)
        self.char_cnn_encoder = char_encoder.CharCNNModel(
            dictionary,
            num_chars,
            embed_dim,
            convolutions_params,
            char_cnn_nonlinear_fn,
            char_cnn_pool_type,
            char_cnn_num_highway_layers,
            char_cnn_output_dim,
            use_pretrained_weights,
            finetune_pretrained_weights,
            weights_file,
        )
        self.embed_tokens = embed_tokens
        token_embed_dim = embed_tokens.embedding_dim
        self.word_layer_norm = nn.LayerNorm(token_embed_dim)

        char_embed_dim = (
            char_cnn_output_dim
            if char_cnn_output_dim != -1
            else sum(out_dim for (out_dim, _) in convolutions_params)
        )
        self.char_layer_norm = nn.LayerNorm(char_embed_dim)
        self.word_dim = char_embed_dim + token_embed_dim
        self.char_scale = math.sqrt(char_embed_dim / self.word_dim)
        self.word_scale = math.sqrt(token_embed_dim / self.word_dim)

        # Project the concatenated char + token representation to the
        # transformer embedding dim only when the dims do not already match.
        self.word_to_transformer_embed = None
        if self.word_dim != args.encoder_embed_dim:
            self.word_to_transformer_embed = fairseq_transformer.Linear(
                self.word_dim, args.encoder_embed_dim
            )

        self.dropout = args.dropout

        self.padding_idx = dictionary.pad()
        self.embed_positions = fairseq_transformer.PositionalEmbedding(
            1024,
            args.encoder_embed_dim,
            self.padding_idx,
            learned=args.encoder_learned_pos,
        )

        self.transformer_encoder_given_embeddings = TransformerEncoderGivenEmbeddings(
            args=args, proj_to_decoder=True
        )

        # Variable tracker
        self.tracker = VariableTracker()

        # Initialize adversarial mode
        self.set_gradient_tracking_mode(False)
        self.set_embed_noising_mode(False)

        # disables sorting and word-length thresholding if True
        # (enables ONNX tracing of length-sorted input with batch_size = 1)
        self.onnx_export_model = False

    def prepare_for_onnx_export_(self):
        self.onnx_export_model = True

    def set_gradient_tracking_mode(self, mode=True):
        """This allows AdversarialTrainer to turn on retain_grad when
        running the adversarial example generation model."""
        self.tracker.reset()
        self.track_gradients = mode

    def set_embed_noising_mode(self, mode=True):
        """This allows the adversarial trainer to turn embedding noising
        layers on and off. In regular training, this mode is off and is not
        included in the forward pass.
        """
        self.embed_noising_mode = mode

    def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
        self.tracker.reset()
        # char_inds has shape (batch_size, max_words_per_sent, max_word_len)
        bsz, seqlen, maxchars = char_inds.size()
        # char_cnn_encoder takes input (max_word_length, total_words)
        char_inds_flat = char_inds.view(-1, maxchars).t()
        # output (total_words, encoder_dim)
        char_cnn_output = self.char_cnn_encoder(char_inds_flat)
        x = char_cnn_output.view(bsz, seqlen, char_cnn_output.shape[-1])
        x = x.transpose(0, 1)  # (seqlen, bsz, char_cnn_output_dim)
        x = self.char_layer_norm(x)
        x = self.char_scale * x

        embedded_tokens = self.embed_tokens(src_tokens)
        # (seqlen, bsz, token_embed_dim)
        embedded_tokens = embedded_tokens.transpose(0, 1)
        embedded_tokens = self.word_layer_norm(embedded_tokens)
        embedded_tokens = self.word_scale * embedded_tokens
        x = torch.cat([x, embedded_tokens], dim=2)

        self.tracker.track(x, "token_embeddings", retain_grad=self.track_gradients)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        if self.word_to_transformer_embed is not None:
            x = self.word_to_transformer_embed(x)
        positions = self.embed_positions(src_tokens)
        x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask (B x T)
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        x = self.transformer_encoder_given_embeddings(
            x=x, positions=positions, encoder_padding_mask=encoder_padding_mask
        )

        if self.onnx_export_model and encoder_padding_mask is None:
            encoder_padding_mask = torch.Tensor([]).type_as(src_tokens)

        return x, src_tokens, encoder_padding_mask

    def reorder_encoder_out(self, encoder_out, new_order):
        (x, src_tokens, encoder_padding_mask) = encoder_out
        if x is not None:
            x = x.index_select(1, new_order)
        if src_tokens is not None:
            src_tokens = src_tokens.index_select(0, new_order)
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
        return (x, src_tokens, encoder_padding_mask)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if "encoder.embed_positions.weights" in state_dict:
                del state_dict["encoder.embed_positions.weights"]
            state_dict["encoder.embed_positions._float_tensor"] = torch.FloatTensor(1)
        return state_dict
class LSTMSequenceEncoder(FairseqEncoder):
    """RNN encoder using nn.LSTM for cuDNN support / ONNX exportability."""

    @staticmethod
    def LSTM(input_size, hidden_size, **kwargs):
        m = nn.LSTM(input_size, hidden_size, **kwargs)
        for name, param in m.named_parameters():
            if "weight" in name or "bias" in name:
                param.data.uniform_(-0.1, 0.1)
        return m

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        freeze_embed=False,
        cell_type="lstm",
        hidden_dim=512,
        num_layers=1,
        dropout_in=0.1,
        dropout_out=0.1,
        residual_level=None,
        bidirectional=False,
        pretrained_embed=None,
        word_dropout_params=None,
        padding_value=0,
        left_pad=True,
    ):
        assert cell_type == "lstm", 'sequence-lstm requires cell_type="lstm"'

        super().__init__(dictionary)
        self.dictionary = dictionary
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.residual_level = residual_level
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.padding_value = padding_value
        self.left_pad = left_pad

        self.embed_tokens = Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embed_dim,
            padding_idx=self.padding_idx,
            freeze_embed=freeze_embed,
        )
        pytorch_translate_utils.load_embedding(
            embedding=self.embed_tokens,
            dictionary=dictionary,
            pretrained_embed=pretrained_embed,
        )
        self.word_dim = embed_dim

        self.layers = nn.ModuleList([])
        for layer in range(num_layers):
            is_layer_bidirectional = self.bidirectional and layer == 0
            self.layers.append(
                LSTMSequenceEncoder.LSTM(
                    self.word_dim if layer == 0 else hidden_dim,
                    hidden_dim // 2 if is_layer_bidirectional else hidden_dim,
                    num_layers=1,
                    dropout=self.dropout_out,
                    bidirectional=is_layer_bidirectional,
                )
            )

        self.num_layers = len(self.layers)
        self.word_dropout_module = None
        if (
            word_dropout_params
            and word_dropout_params["word_dropout_freq_threshold"] is not None
            and word_dropout_params["word_dropout_freq_threshold"] > 0
        ):
            self.word_dropout_module = word_dropout.WordDropout(
                dictionary, word_dropout_params
            )

        # Variable tracker
        self.tracker = VariableTracker()

        # Initialize adversarial mode
        self.set_gradient_tracking_mode(False)

    def forward(self, src_tokens, src_lengths):
        if self.left_pad:
            # convert left-padding to right-padding
            src_tokens = utils.convert_padding_direction(
                src_tokens, self.padding_idx, left_to_right=True
            )

        # If we're generating adversarial examples we need to keep track of
        # some internal variables
        self.tracker.reset()

        if self.word_dropout_module is not None:
            src_tokens = self.word_dropout_module(src_tokens)

        bsz, seqlen = src_tokens.size()

        # embed tokens
        x = self.embed_tokens(src_tokens)
        # Track token embeddings
        self.tracker.track(x, "token_embeddings", retain_grad=self.track_gradients)

        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # Allows compatibility with Caffe2 inputs for tracing (int32)
        # as well as the current format of Fairseq-Py inputs (int64)
        if src_lengths.dtype is torch.int64:
            src_lengths = src_lengths.int()

        # Generate packed seq to deal with varying source seq length
        # packed_input is of type PackedSequence, which consists of:
        # element [0]: a tensor, the packed data, and
        # element [1]: a list of integers, the batch size for each step
        packed_input = pack_padded_sequence(x, src_lengths)

        final_hiddens, final_cells = [], []
        for i, rnn_layer in enumerate(self.layers):
            if self.bidirectional and i == 0:
                h0 = x.new(2, bsz, self.hidden_dim // 2).zero_()
                c0 = x.new(2, bsz, self.hidden_dim // 2).zero_()
            else:
                h0 = x.new(1, bsz, self.hidden_dim).zero_()
                c0 = x.new(1, bsz, self.hidden_dim).zero_()

            # apply LSTM along entire sequence
            current_output, (h_last, c_last) = rnn_layer(packed_input, (h0, c0))

            # final state shapes: (bsz, hidden_dim)
            if self.bidirectional and i == 0:
                # concatenate last states for forward and backward LSTM
                h_last = torch.cat((h_last[0, :, :], h_last[1, :, :]), dim=1)
                c_last = torch.cat((c_last[0, :, :], c_last[1, :, :]), dim=1)
            else:
                h_last = h_last.squeeze(dim=0)
                c_last = c_last.squeeze(dim=0)

            final_hiddens.append(h_last)
            final_cells.append(c_last)

            if self.residual_level is not None and i >= self.residual_level:
                packed_input[0] = packed_input.clone()[0] + current_output[0]
            else:
                packed_input = current_output

        # Reshape to [num_layer, batch_size, hidden_dim]
        final_hiddens = torch.cat(final_hiddens, dim=0).view(
            self.num_layers, *final_hiddens[0].size()
        )
        final_cells = torch.cat(final_cells, dim=0).view(
            self.num_layers, *final_cells[0].size()
        )

        # [max_seqlen, batch_size, hidden_dim]
        unpacked_output, _ = pad_packed_sequence(
            packed_input, padding_value=self.padding_value
        )

        return (unpacked_output, final_hiddens, final_cells, src_lengths, src_tokens)

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder all outputs according to new_order."""
        return reorder_encoder_output(encoder_out, new_order)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number

    def set_gradient_tracking_mode(self, mode=True):
        self.tracker.reset()
        self.track_gradients = mode
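
# Illustrative sketch (not part of the original module): what
# pack_padded_sequence / pad_packed_sequence do to the (T, B, C) input used by
# the encoder above. Lengths are assumed to be pre-sorted in decreasing order,
# as they are by the time forward() runs; all sizes are made-up example values.
def _sketch_packed_sequence_roundtrip():
    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    seqlen, bsz, dim = 4, 3, 2
    x = torch.randn(seqlen, bsz, dim)
    lengths = torch.tensor([4, 3, 1])  # one entry per batch element, descending

    packed = pack_padded_sequence(x, lengths)
    # packed.data holds only the non-padding steps: sum(lengths) rows
    assert packed.data.shape == (int(lengths.sum()), dim)
    # packed.batch_sizes[t] = number of sequences still active at time step t
    assert packed.batch_sizes.tolist() == [3, 2, 2, 1]

    unpacked, out_lengths = pad_packed_sequence(packed, padding_value=0)
    assert unpacked.shape == (seqlen, bsz, dim)
    assert out_lengths.tolist() == [4, 3, 1]
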
class TransformerEncoder(FairseqEncoder):
    """Transformer encoder."""

    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
        super().__init__(dictionary)
        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = fairseq_transformer.PositionalEmbedding(
            1024,
            embed_dim,
            self.padding_idx,
            left_pad=left_pad,
            learned=args.encoder_learned_pos,
        )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                fairseq_transformer.TransformerEncoderLayer(args)
                for i in range(args.encoder_layers)
            ]
        )

        # Variable tracker
        self.tracker = VariableTracker()

        # Initialize adversarial mode
        self.set_gradient_tracking_mode(False)

    def forward(self, src_tokens, src_lengths):
        # Initialize the tracker to keep track of internal variables
        self.tracker.reset()

        # Embed tokens
        x = self.embed_tokens(src_tokens)
        # Track token embeddings
        self.tracker.track(x, "token_embeddings", retain_grad=self.track_gradients)

        # Add position embeddings and dropout
        x = self.embed_scale * x
        x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask (B x T)
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        return x, src_tokens, encoder_padding_mask

    def reorder_encoder_out(self, encoder_out, new_order):
        (x, src_tokens, encoder_padding_mask) = encoder_out
        if x is not None:
            x = x.index_select(1, new_order)
        if src_tokens is not None:
            src_tokens = src_tokens.index_select(0, new_order)
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
        return (x, src_tokens, encoder_padding_mask)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if "encoder.embed_positions.weights" in state_dict:
                del state_dict["encoder.embed_positions.weights"]
            state_dict["encoder.embed_positions._float_tensor"] = torch.FloatTensor(1)
        return state_dict

    def set_gradient_tracking_mode(self, mode=True):
        self.tracker.reset()
        self.track_gradients = mode
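
# Illustrative sketch (not part of the original module): the padding-mask logic
# used in forward() above. `pad_idx` and the token ids are made-up values.
def _sketch_encoder_padding_mask():
    import torch

    pad_idx = 1
    # batch of 2 sentences, the second one padded to length 4
    src_tokens = torch.tensor([[5, 6, 7, 8], [5, 6, 1, 1]])

    encoder_padding_mask = src_tokens.eq(pad_idx)  # B x T, True at pad positions
    assert encoder_padding_mask.tolist() == [
        [False, False, False, False],
        [False, False, True, True],
    ]

    # When nothing is padded the mask is dropped entirely (set to None), so the
    # attention layers can skip the masking work.
    no_pad_batch = torch.tensor([[5, 6], [7, 8]])
    mask = no_pad_batch.eq(pad_idx)
    if not mask.any():
        mask = None
    assert mask is None
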
class TransformerEncoder(FairseqEncoder):
    """Transformer encoder."""

    def __init__(
        self, args, dictionary, embed_tokens, left_pad=False, proj_to_decoder=True
    ):
        super().__init__(dictionary)
        self.transformer_embedding = TransformerEmbedding(
            args=args, embed_tokens=embed_tokens, left_pad=left_pad
        )

        self.transformer_encoder_given_embeddings = TransformerEncoderGivenEmbeddings(
            args=args, proj_to_decoder=proj_to_decoder
        )

        # Variable tracker
        self.tracker = VariableTracker()

        # Initialize adversarial mode
        self.set_gradient_tracking_mode(False)
        self.set_embed_noising_mode(False)

    def forward(self, src_tokens, src_lengths):
        # Initialize the tracker to keep track of internal variables
        self.tracker.reset()

        x, encoder_padding_mask, positions = self.transformer_embedding(
            src_tokens=src_tokens, src_lengths=src_lengths
        )

        # Track token embeddings
        self.tracker.track(x, "token_embeddings", retain_grad=self.track_gradients)

        x = self.transformer_encoder_given_embeddings(
            x=x, positions=positions, encoder_padding_mask=encoder_padding_mask
        )

        return x, src_tokens, encoder_padding_mask

    def reorder_encoder_out(self, encoder_out, new_order):
        (x, src_tokens, encoder_padding_mask) = encoder_out
        if x is not None:
            x = x.index_select(1, new_order)
        if src_tokens is not None:
            src_tokens = src_tokens.index_select(0, new_order)
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
        return (x, src_tokens, encoder_padding_mask)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.transformer_embedding.embed_positions.max_positions()

    def upgrade_state_dict_named(self, state_dict, name):
        if isinstance(
            self.transformer_embedding.embed_positions, SinusoidalPositionalEmbedding
        ):
            if f"{name}.transformer_embedding.embed_positions.weights" in state_dict:
                del state_dict[f"{name}.transformer_embedding.embed_positions.weights"]
            state_dict[
                f"{name}.transformer_embedding.embed_positions._float_tensor"
            ] = torch.FloatTensor(1)
        return state_dict

    def set_gradient_tracking_mode(self, mode=True):
        self.tracker.reset()
        self.track_gradients = mode

    def set_embed_noising_mode(self, mode=True):
        """This allows the adversarial trainer to turn embedding noising
        layers on and off. In regular training, this mode is off and is not
        included in the forward pass.
        """
        self.embed_noising_mode = mode
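
# Illustrative sketch (not part of the original module): why reorder_encoder_out
# above index_selects dim 1 for `x` but dim 0 for `src_tokens` and the padding
# mask. `x` is laid out (T, B, C) while the other outputs are (B, T); the sizes
# and the new order are made-up values of the kind beam search would produce.
def _sketch_reorder_encoder_out():
    import torch

    seqlen, bsz, dim = 3, 2, 4
    x = torch.arange(seqlen * bsz * dim, dtype=torch.float).view(seqlen, bsz, dim)
    src_tokens = torch.tensor([[5, 6, 7], [8, 9, 10]])

    # e.g. duplicate batch element 0 and drop element 1
    new_order = torch.tensor([0, 0])

    x_reordered = x.index_select(1, new_order)  # batch dim of x is dim 1
    tokens_reordered = src_tokens.index_select(0, new_order)  # batch dim is dim 0

    assert torch.equal(x_reordered[:, 0], x[:, 0])
    assert torch.equal(x_reordered[:, 1], x[:, 0])
    assert tokens_reordered.tolist() == [[5, 6, 7], [5, 6, 7]]
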
class CharCNNEncoder(FairseqEncoder):
    """
    Character-level CNN encoder to generate word representations, as input to
    an RNN encoder.
    """

    def __init__(
        self,
        dictionary,
        num_chars=50,
        unk_only_char_encoding=False,
        embed_dim=32,
        token_embed_dim=256,
        freeze_embed=False,
        normalize_embed=False,
        char_cnn_params="[(128, 3), (128, 5)]",
        char_cnn_nonlinear_fn="tanh",
        char_cnn_pool_type="max",
        char_cnn_num_highway_layers=0,
        char_cnn_output_dim=-1,
        hidden_dim=512,
        num_layers=1,
        dropout_in=0.1,
        dropout_out=0.1,
        residual_level=None,
        bidirectional=False,
        word_dropout_params=None,
        use_pretrained_weights=False,
        finetune_pretrained_weights=False,
        weights_file=None,
    ):
        super().__init__(dictionary)
        self.dropout_in = dropout_in

        convolutions_params = literal_eval(char_cnn_params)
        self.char_cnn_encoder = char_encoder.CharCNNModel(
            dictionary,
            num_chars,
            embed_dim,
            convolutions_params,
            char_cnn_nonlinear_fn,
            char_cnn_pool_type,
            char_cnn_num_highway_layers,
            char_cnn_output_dim,
            use_pretrained_weights,
            finetune_pretrained_weights,
            weights_file,
        )

        self.embed_tokens = None
        num_tokens = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.unk_idx = dictionary.unk()
        if token_embed_dim > 0:
            self.embed_tokens = rnn.Embedding(
                num_embeddings=num_tokens,
                embedding_dim=token_embed_dim,
                padding_idx=self.padding_idx,
                freeze_embed=freeze_embed,
                normalize_embed=normalize_embed,
            )
        self.word_dim = (
            char_cnn_output_dim
            if char_cnn_output_dim != -1
            else sum(out_dim for (out_dim, _) in convolutions_params)
        )
        self.token_embed_dim = token_embed_dim

        self.unk_only_char_encoding = unk_only_char_encoding
        if self.unk_only_char_encoding:
            assert char_cnn_output_dim == token_embed_dim, (
                "char_cnn_output_dim (%d) must equal token_embed_dim (%d)"
                % (char_cnn_output_dim, token_embed_dim)
            )
            self.word_dim = token_embed_dim
        else:
            self.word_dim = self.word_dim + token_embed_dim

        self.bilstm = rnn.BiLSTM(
            num_layers=num_layers,
            bidirectional=bidirectional,
            embed_dim=self.word_dim,
            hidden_dim=hidden_dim,
            dropout=dropout_out,
            residual_level=residual_level,
        )

        # Variable tracker
        self.tracker = VariableTracker()

        # Initialize adversarial mode
        self.set_gradient_tracking_mode(False)
        self.set_embed_noising_mode(False)

    def set_gradient_tracking_mode(self, mode=True):
        """This allows AdversarialTrainer to turn on retain_grad when
        running the adversarial example generation model."""
        self.tracker.reset()
        self.track_gradients = mode

    def set_embed_noising_mode(self, mode=True):
        """This allows the adversarial trainer to turn embedding noising
        layers on and off. In regular training, this mode is off and is not
        included in the forward pass.
        """
        self.embed_noising_mode = mode

    def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
        self.tracker.reset()
        # char_inds has shape (batch_size, max_words_per_sent, max_word_len)
        bsz, seqlen, maxchars = char_inds.size()
        # (total_words, max_word_len); transposed to the
        # (max_word_length, total_words) layout expected by char_cnn_encoder
        # right before the CNN call below
        char_inds_flat = char_inds.view(-1, maxchars)

        if self.unk_only_char_encoding:
            assert (
                self.embed_tokens is not None
            ), "token_embed_dim should be > 0 when unk_only_char_encoding is true!"
            unk_masks = (src_tokens == self.unk_idx).view(-1)
            unk_indices = torch.nonzero(unk_masks).squeeze()
            if unk_indices.dim() > 0 and unk_indices.size(0) > 0:
                char_inds_flat = torch.index_select(char_inds_flat, 0, unk_indices)
                char_inds_flat = char_inds_flat.view(-1, maxchars)
            else:
                char_inds_flat = None

        if char_inds_flat is not None:
            # (bsz * seqlen, encoder_dim)
            char_cnn_output = self.char_cnn_encoder(char_inds_flat.t())
            x = char_cnn_output
        else:
            # charCNN is not needed
            x = None

        if self.embed_tokens is not None:
            # (bsz, seqlen, token_embed_dim)
            embedded_tokens = self.embed_tokens(src_tokens)
            # (bsz * seqlen, token_embed_dim)
            embedded_tokens = embedded_tokens.view(-1, self.token_embed_dim)
            if self.unk_only_char_encoding:
                # charCNN for UNK words only
                if x is not None:
                    x = embedded_tokens.index_copy(0, unk_indices, x)
                else:
                    # no UNK, so charCNN is not needed
                    x = embedded_tokens
            else:
                # charCNN for all words
                x = torch.cat([x, embedded_tokens], dim=1)

        # (bsz, seqlen, x.shape[-1])
        x = x.view(bsz, seqlen, x.shape[-1])
        # (seqlen, bsz, x.shape[-1])
        x = x.transpose(0, 1)

        self.tracker.track(x, "token_embeddings", retain_grad=self.track_gradients)

        if self.dropout_in != 0:
            x = F.dropout(x, p=self.dropout_in, training=self.training)

        embedded_words = x
        unpacked_output, final_hiddens, final_cells = self.bilstm(
            embeddings=x, lengths=src_lengths
        )
        return (
            unpacked_output,
            final_hiddens,
            final_cells,
            src_lengths,
            src_tokens,
            embedded_words,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder all outputs according to new_order."""
        return rnn.reorder_encoder_output(encoder_out, new_order)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number
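
# Illustrative sketch (not part of the original module): the index_select /
# index_copy pattern used above when unk_only_char_encoding is enabled, i.e.
# running the char CNN only for UNK words and splicing the results back into
# the token-embedding matrix. All ids and dims are made-up example values.
def _sketch_unk_only_char_encoding():
    import torch

    unk_idx, embed_dim = 3, 4
    # flattened batch of bsz * seqlen = 6 words, two of them UNK
    src_tokens_flat = torch.tensor([10, 3, 11, 12, 3, 13])
    embedded_tokens = torch.zeros(6, embed_dim)

    unk_masks = src_tokens_flat == unk_idx
    unk_indices = torch.nonzero(unk_masks).squeeze()
    assert unk_indices.tolist() == [1, 4]

    # stand-in for the char CNN output, one row per UNK word (same dim as
    # the token embeddings)
    char_cnn_output = torch.ones(unk_indices.size(0), embed_dim)

    # overwrite only the UNK rows with the char-CNN representations
    x = embedded_tokens.index_copy(0, unk_indices, char_cnn_output)
    assert x[1].sum().item() == embed_dim and x[4].sum().item() == embed_dim
    assert x[0].sum().item() == 0
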
class CharCNNEncoder(FairseqEncoder):
    """
    Character-level CNN encoder to generate word representations, as input to
    an RNN encoder.
    """

    def __init__(
        self,
        dictionary,
        num_chars=50,
        embed_dim=32,
        token_embed_dim=256,
        freeze_embed=False,
        char_cnn_params="[(128, 3), (128, 5)]",
        char_cnn_output_dim=256,
        char_cnn_nonlinear_fn="tanh",
        char_cnn_pool_type="max",
        char_cnn_num_highway_layers=0,
        hidden_dim=512,
        num_layers=1,
        dropout_in=0.1,
        dropout_out=0.1,
        residual_level=None,
        bidirectional=False,
        word_dropout_params=None,
    ):
        super().__init__(dictionary)
        self.dictionary = dictionary
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.residual_level = residual_level
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional

        convolutions_params = literal_eval(char_cnn_params)
        self.char_cnn_encoder = char_encoder.CharCNNModel(
            dictionary,
            num_chars,
            embed_dim,
            convolutions_params,
            char_cnn_nonlinear_fn,
            char_cnn_pool_type,
            char_cnn_num_highway_layers,
        )

        self.embed_tokens = None
        num_tokens = len(dictionary)
        self.padding_idx = dictionary.pad()
        if token_embed_dim > 0:
            self.embed_tokens = rnn.Embedding(
                num_embeddings=num_tokens,
                embedding_dim=token_embed_dim,
                padding_idx=self.padding_idx,
                freeze_embed=freeze_embed,
            )
        self.word_dim = (
            sum(out_dim for (out_dim, _) in convolutions_params) + token_embed_dim
        )

        self.layers = nn.ModuleList([])
        for layer in range(num_layers):
            is_layer_bidirectional = self.bidirectional and layer == 0
            if is_layer_bidirectional:
                assert hidden_dim % 2 == 0, (
                    "encoder_hidden_dim must be even if encoder_bidirectional "
                    "(to be divided evenly between directions)"
                )
            self.layers.append(
                rnn.LSTMSequenceEncoder.LSTM(
                    self.word_dim if layer == 0 else hidden_dim,
                    hidden_dim // 2 if is_layer_bidirectional else hidden_dim,
                    num_layers=1,
                    dropout=self.dropout_out,
                    bidirectional=is_layer_bidirectional,
                )
            )

        self.num_layers = len(self.layers)
        self.word_dropout_module = None
        if (
            word_dropout_params
            and word_dropout_params["word_dropout_freq_threshold"] is not None
            and word_dropout_params["word_dropout_freq_threshold"] > 0
        ):
            self.word_dropout_module = word_dropout.WordDropout(
                dictionary, word_dropout_params
            )

        # Variable tracker
        self.tracker = VariableTracker()

        # Initialize adversarial mode
        self.set_gradient_tracking_mode(False)

    def set_gradient_tracking_mode(self, mode=True):
        """This allows AdversarialTrainer to turn on retain_grad when
        running the adversarial example generation model."""
        self.tracker.reset()
        self.track_gradients = mode

    def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
        self.tracker.reset()
        # char_inds has shape (batch_size, max_words_per_sent, max_word_len)
        bsz, seqlen, maxchars = char_inds.size()
        # char_cnn_encoder takes input (max_word_length, total_words)
        char_inds_flat = char_inds.view(-1, maxchars).t()
        # output (total_words, encoder_dim)
        char_cnn_output = self.char_cnn_encoder(char_inds_flat)
        x = char_cnn_output.view(bsz, seqlen, char_cnn_output.shape[-1])
        x = x.transpose(0, 1)  # (seqlen, bsz, char_cnn_output_dim)

        if self.embed_tokens is not None:
            embedded_tokens = self.embed_tokens(src_tokens)
            # (seqlen, bsz, token_embed_dim)
            embedded_tokens = embedded_tokens.transpose(0, 1)
            # (seqlen, bsz, total_word_embed_dim)
            x = torch.cat([x, embedded_tokens], dim=2)

        self.tracker.track(x, "token_embeddings", retain_grad=self.track_gradients)

        if self.dropout_in != 0:
            x = F.dropout(x, p=self.dropout_in, training=self.training)
        embedded_words = x

        # The rest is the same as CharRNNEncoder, so could be refactored
        # Generate packed seq to deal with varying source seq length
        # packed_input is of type PackedSequence, which consists of:
        # element [0]: a tensor, the packed data, and
        # element [1]: a list of integers, the batch size for each step
        packed_input = pack_padded_sequence(x, src_lengths)

        final_hiddens, final_cells = [], []
        for i, rnn_layer in enumerate(self.layers):
            if self.bidirectional and i == 0:
                h0 = x.data.new(2, bsz, self.hidden_dim // 2).zero_()
                c0 = x.data.new(2, bsz, self.hidden_dim // 2).zero_()
            else:
                h0 = x.data.new(1, bsz, self.hidden_dim).zero_()
                c0 = x.data.new(1, bsz, self.hidden_dim).zero_()

            # apply LSTM along entire sequence
            current_output, (h_last, c_last) = rnn_layer(packed_input, (h0, c0))

            # final state shapes: (bsz, hidden_dim)
            if self.bidirectional and i == 0:
                # concatenate last states for forward and backward LSTM
                h_last = torch.cat((h_last[0, :, :], h_last[1, :, :]), dim=1)
                c_last = torch.cat((c_last[0, :, :], c_last[1, :, :]), dim=1)
            else:
                h_last = h_last.squeeze(dim=0)
                c_last = c_last.squeeze(dim=0)

            final_hiddens.append(h_last)
            final_cells.append(c_last)

            if self.residual_level is not None and i >= self.residual_level:
                packed_input[0] = packed_input.clone()[0] + current_output[0]
            else:
                packed_input = current_output

        # Reshape to [num_layer, batch_size, hidden_dim]
        final_hiddens = torch.cat(final_hiddens, dim=0).view(
            self.num_layers, *final_hiddens[0].size()
        )
        final_cells = torch.cat(final_cells, dim=0).view(
            self.num_layers, *final_cells[0].size()
        )

        # [max_seqlen, batch_size, hidden_dim]
        unpacked_output, _ = pad_packed_sequence(packed_input)

        return (
            unpacked_output,
            final_hiddens,
            final_cells,
            src_lengths,
            src_tokens,
            embedded_words,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder all outputs according to new_order."""
        return rnn.reorder_encoder_output(encoder_out, new_order)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number
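
# Illustrative sketch (not part of the original module): how the first
# (bidirectional) LSTM layer's final states are merged in the encoders above.
# nn.LSTM returns h_last with shape (num_directions, bsz, hidden_dim // 2);
# the forward and backward halves are concatenated into a single
# (bsz, hidden_dim) state. Sizes are made-up example values.
def _sketch_bidirectional_state_merge():
    import torch

    bsz, hidden_dim = 3, 8
    h_last = torch.randn(2, bsz, hidden_dim // 2)  # 2 = forward + backward

    merged = torch.cat((h_last[0, :, :], h_last[1, :, :]), dim=1)
    assert merged.shape == (bsz, hidden_dim)
    # first half comes from the forward direction, second half from the backward
    assert torch.equal(merged[:, : hidden_dim // 2], h_last[0])
    assert torch.equal(merged[:, hidden_dim // 2 :], h_last[1])
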