def __init__(self, config, dataset):
    ''' Initialize the Transformer '''
    super(InterleaveFixedPosEmbEncoderOnlyTransformer, self).__init__()
    self.dataset = dataset
    self.embedding = TokenEmbedding(dataset.vocab_size,
                                    config.embedding_size,
                                    padding_idx=self.padding_idx)
    self.position_embedding = PositionEmbedding(config.embedding_size)
    self.num_layers = config.num_layers

    encoder_positional_embedding_list = []
    for _ in range(self.num_layers // 2):
        position_embedding_encoder = LearnedPositionalEmbedding(
            dataset.max_input_length, config.embedding_size,
            self.padding_idx)
        nn.init.normal_(position_embedding_encoder.weight,
                        mean=0,
                        std=config.embedding_size ** -0.5)
        if self.padding_idx is not None:
            nn.init.constant_(
                position_embedding_encoder.weight[self.padding_idx], 0)
        encoder_positional_embedding_list.append(position_embedding_encoder)
    self.encoder_positional_embeddings = nn.ModuleList(
        encoder_positional_embedding_list)

    self.position_embedding_decoder = LearnedPositionalEmbedding(
        dataset.max_target_length, config.embedding_size, self.padding_idx)
    nn.init.normal_(self.position_embedding_decoder.weight,
                    mean=0,
                    std=config.embedding_size ** -0.5)
    if self.padding_idx is not None:
        nn.init.constant_(
            self.position_embedding_decoder.weight[self.padding_idx], 0)

    self.dropout = nn.Dropout(config.dropout_p, inplace=True)

    # Unique attention attributes
    self.attn_ofs_uniq = list(
        set(config.enc_attn_offset + config.dec_attn_offset +
            config.enc_dec_attn_offset))
    self.attn_std_uniq = list(
        set(config.enc_attn_std + config.dec_attn_std +
            config.enc_dec_attn_std))

    # Allow for overriding the encoders and decoders in derived classes
    self.encoders = self.create_encoders(config)
    self.decoders = self.create_decoders(config)

    self.label_smoothing = LabelSmoothingLoss(config.label_smoothing or 0,
                                              ignore_index=self.padding_idx,
                                              reduction='none')
    self.cross_entropy = nn.CrossEntropyLoss(ignore_index=self.padding_idx,
                                             reduction='none')
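# --- Illustrative sketch (not part of the original model code) ---
# The __init__ above initialises each learned positional-embedding table with
# normal weights of std = embedding_size ** -0.5 and zeroes the padding row.
# The same pattern is reproduced below with a plain nn.Embedding standing in
# for the project's LearnedPositionalEmbedding, whose internals are not shown
# in this file; names and sizes in the demo are assumptions.
import torch
from torch import nn


def init_positional_embedding_sketch(num_positions, embedding_size, padding_idx=None):
    ''' Build an embedding table initialised the same way as above. '''
    emb = nn.Embedding(num_positions, embedding_size, padding_idx=padding_idx)
    nn.init.normal_(emb.weight, mean=0, std=embedding_size ** -0.5)
    if padding_idx is not None:
        nn.init.constant_(emb.weight[padding_idx], 0)
    return emb


# usage: the padding row stays exactly zero after initialisation
_pos_emb = init_positional_embedding_sketch(512, 256, padding_idx=0)
assert torch.allclose(_pos_emb.weight[0], torch.zeros(256))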
def __init__(self, config, dataset):
    ''' Initialize the ParseTransformer '''
    super(ParseTransformer, self).__init__(config, dataset)
    self.span = 1

    args = [config.num_heads, config.embedding_size, config.hidden_dim]
    self.annotation_decoders = nn.ModuleList([
        TransformerDecoderLayer(*args, dropout_p=config.dropout_p)
        for _ in range(config.parse_num_layers)
    ])
    self.annotation_embedding = TokenEmbedding(
        dataset.annotation_vocab_size,
        config.embedding_size,
        padding_idx=self.annotation_padding_idx)
    self.annotation_cross_entropy = nn.CrossEntropyLoss(
        ignore_index=self.annotation_padding_idx, reduction='none')
class Transformer(nn.Module):
    ''' The Transformer module '''
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(Transformer, self).__init__()
        self.dataset = dataset
        self.span = config.span
        self.embedding = TokenEmbedding(dataset.vocab_size,
                                        config.embedding_size,
                                        padding_idx=self.padding_idx)
        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        # Allow for overriding the encoders and decoders in derived classes
        self.encoders = type(self).create_encoders(config)
        self.decoders = type(self).create_decoders(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none')
        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=self.padding_idx,
                                                 reduction='none')

    @classmethod
    def create_encoders(cls, config):
        ''' Create the transformer encoders '''
        kwargs = {'dropout_p': config.dropout_p}
        args = [config.num_heads, config.embedding_size, config.hidden_dim]
        return nn.ModuleList([
            TransformerEncoderLayer(*args, **kwargs)
            for _ in range(config.num_layers)
        ])

    @classmethod
    def create_decoders(cls, config):
        ''' Create the transformer decoders '''
        kwargs = {'dropout_p': config.dropout_p, 'span': config.span}
        args = [config.num_heads, config.embedding_size, config.hidden_dim]
        return nn.ModuleList([
            TransformerDecoderLayer(*args, **kwargs)
            for _ in range(config.num_layers)
        ])

    @property
    def sos_idx(self):
        ''' Return the sos index '''
        return self.dataset.sos_idx

    @property
    def padding_idx(self):
        ''' Return the padding index '''
        return self.dataset.padding_idx

    def translator(self, config):
        ''' Get a translator for this model '''
        return Translator(config, self, self.dataset)

    def reset_named_parameters(self, modules):
        ''' Reset the parameters of the named modules '''
        if 'encoder' in modules:
            for encoder in self.encoders:
                encoder.reset_parameters()
        if 'decoder' in modules:
            for decoder in self.decoders:
                decoder.reset_parameters()
        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch):  # pylint:disable=arguments-differ
        ''' A batch of inputs and targets '''
        decoded = self.decode(
            self.encode(batch['inputs']),
            right_shift(right_shift(batch['targets']),
                        shift=self.span - 1,
                        fill=self.sos_idx),
        )

        logits = decoded['logits']
        dims = list(range(1, logits.dim()))
        targets = left_shift(batch['targets'])
        nll = self.cross_entropy(logits, targets).sum(dims[:-1])
        smoothed_nll = self.label_smoothing(logits, targets).sum(dims)
        return smoothed_nll, nll

    def encode(self, inputs):
        ''' Encode the inputs '''
        encoded = {
            'state': self.embed(inputs, self.embedding),
            'mask': inputs.eq(self.padding_idx)
        }
        for encoder in self.encoders:
            encoded = encoder(encoded)
        return encoded

    def decode(self, encoded, targets, decoders=None, embedding=None,
               cache=None, mask=None):
        ''' Decode the encoded sequence to the targets '''
        if decoders is None:
            decoders = self.decoders
        if embedding is None:
            embedding = self.embedding

        decoded = {
            'cache': cache,
            'state': self.embed(targets, embedding),
            'mask': targets.eq(self.padding_idx) if mask is None else mask
        }
        for decoder in decoders:
            decoded = decoder(decoded, encoded)

        # compute projection to the vocabulary
        state = decoded['state']
        if cache is not None:
            state = state[:, -self.span:]

        return {
            'cache': decoded.get('cache'),
            'logits': embedding(state, transpose=True).transpose(2, 1),  # transpose to B x C x ...
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs '''
        return self.dropout(
            token_embedding(inputs) + self.position_embedding(inputs))
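# --- Illustrative sketch (not part of the original classes) ---
# Transformer.forward() feeds the decoder a right-shifted copy of the targets
# and scores the logits against a left-shifted copy (standard teacher forcing).
# The project's right_shift / left_shift helpers are defined elsewhere and may
# handle padding and multi-token spans differently; the snippet below only
# sketches the assumed alignment for span == 1.
import torch


def demo_teacher_forcing_shift(targets, sos_idx):
    ''' targets: (batch, length) tensor of gold token ids '''
    sos = targets.new_full((targets.size(0), 1), sos_idx)
    decoder_input = torch.cat([sos, targets[:, :-1]], dim=1)  # right shift
    return decoder_input, targets  # predict position t from positions < t


# usage: position t of the decoder input lines up with gold token t
_targets = torch.tensor([[5, 6, 7, 8]])
_decoder_input, _prediction_targets = demo_teacher_forcing_shift(_targets, sos_idx=1)
assert _decoder_input.tolist() == [[1, 5, 6, 7]]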
class NewTransformer(nn.Module):
    ''' The New Transformer module '''
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(NewTransformer, self).__init__()
        self.dataset = dataset
        self.embedding = TokenEmbedding(
            dataset.vocab_size,
            config.embedding_size,
            padding_idx=self.padding_idx
        )
        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        # Unique attention attributes
        self.attn_ofs_uniq = list(set(
            config.enc_attn_offset + config.dec_attn_offset +
            config.enc_dec_attn_offset))
        self.attn_std_uniq = list(set(
            config.enc_attn_std + config.dec_attn_std +
            config.enc_dec_attn_std))

        # Allow for overriding the encoders and decoders in derived classes
        self.encoders = self.create_encoders(config)
        self.decoders = self.create_decoders(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )

    def create_encoders(self, config):
        ''' Create the transformer encoders '''
        kwargs = {'dropout_p': config.dropout_p}
        if config.ffn_layer == -1:
            config.ffn_layer = [1] * config.num_layers
        assert len(config.ffn_layer) == config.num_layers

        attn_config = {'attn_type': config.enc_attn_type,
                       'attn_std': config.enc_attn_std,
                       'attn_offset': config.enc_attn_offset,
                       'num_layers': config.num_layers,
                       'num_heads': config.num_heads,
                       'which_attn': 'encoder',
                       'attn_threshold': config.enc_attn_threshold,
                       'attn_window': config.enc_attn_window,
                       'attn_impl': config.enc_attn_impl,
                       'ffn_layer': config.ffn_layer,
                       'attn_ofs_uniq': self.attn_ofs_uniq,
                       'attn_std_uniq': self.attn_std_uniq}
        args = [attn_config, config.num_heads, config.embedding_size,
                config.hidden_dim]
        encoders = nn.ModuleList([
            TransformerEncoderLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])
        return encoders

    def create_decoders(self, config):
        ''' Create the transformer decoders '''
        kwargs = {'dropout_p': config.dropout_p}
        if config.ffn_layer == -1:
            config.ffn_layer = [1] * config.num_layers
        assert len(config.ffn_layer) == config.num_layers

        dec_attn_config = {'attn_type': config.dec_attn_type,
                           'attn_std': config.dec_attn_std,
                           'attn_offset': config.dec_attn_offset,
                           'num_layers': config.num_layers,
                           'num_heads': config.num_heads,
                           'which_attn': 'decoder',
                           'attn_threshold': config.dec_attn_threshold,
                           'attn_window': config.dec_attn_window,
                           'attn_impl': config.dec_attn_impl,
                           'ffn_layer': config.ffn_layer,
                           'attn_ofs_uniq': self.attn_ofs_uniq,
                           'attn_std_uniq': self.attn_std_uniq}
        enc_dec_attn_config = {'attn_type': config.enc_dec_attn_type,
                               'attn_std': config.enc_dec_attn_std,
                               'attn_offset': config.enc_dec_attn_offset,
                               'num_layers': config.num_layers,
                               'num_heads': config.num_heads,
                               'word_count_ratio': self.dataset.word_count_ratio,
                               'which_attn': 'source',
                               'enc_dec_attn_layer': config.enc_dec_attn_layer,
                               'enc_dec_attn_num_heads': config.enc_dec_attn_num_heads,
                               'attn_threshold': config.enc_dec_attn_threshold,
                               'attn_window': config.enc_dec_attn_window,
                               'attn_impl': config.enc_dec_attn_impl,
                               'ffn_layer': config.ffn_layer,
                               'attn_ofs_uniq': self.attn_ofs_uniq,
                               'attn_std_uniq': self.attn_std_uniq}
        args = [dec_attn_config, enc_dec_attn_config, config.num_heads,
                config.embedding_size, config.hidden_dim]
        decoders = nn.ModuleList([
            TransformerDecoderLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])
        return decoders

    @property
    def sos_idx(self):
        ''' Return the sos index '''
        return self.dataset.sos_idx

    @property
    def padding_idx(self):
        ''' Return the padding index '''
        return self.dataset.padding_idx

    def translator(self, config):
        ''' Get a translator for this model '''
        return Translator(config, self, self.dataset)

    def reset_named_parameters(self, modules):
        ''' Reset the parameters of the named modules '''
        if 'encoder' in modules:
            for encoder in self.encoders:
                encoder.reset_parameters()
        if 'decoder' in modules:
            for decoder in self.decoders:
                decoder.reset_parameters()
        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch):  # pylint:disable=arguments-differ
        ''' A batch of inputs and targets '''
        decoded = self.decode(
            self.encode(batch['inputs']),
            right_shift(batch['targets']),
            input_lens=batch['input_lens']
        )

        logits = decoded['logits']
        dims = list(range(1, logits.dim()))
        targets = left_shift(batch['targets'])
        nll = self.cross_entropy(logits, targets).sum(dims[:-1])
        smoothed_nll = self.label_smoothing(logits, targets).sum(dims)
        return smoothed_nll, nll

    def encode(self, inputs):
        ''' Encode the inputs '''
        word_embedding = self.embed(inputs, self.embedding)
        encoded = {
            'state': word_embedding,
            'mask': inputs.eq(self.padding_idx)
        }
        for i, encoder in enumerate(self.encoders):
            encoded = encoder(encoded, i)
        return encoded

    def decode(self, encoded, targets, decoders=None, embedding=None,
               cache=None, mask=None, input_lens=None):
        ''' Decode the encoded sequence to the targets '''
        if decoders is None:
            decoders = self.decoders
        if embedding is None:
            embedding = self.embedding

        word_embedding = self.embed(targets, embedding)
        decoded = {
            'cache': cache,
            'state': word_embedding,
            'mask': targets.eq(self.padding_idx) if mask is None else mask
        }
        for i, decoder in enumerate(decoders):
            decoded = decoder(decoded, encoded, i)

        # compute projection to the vocabulary
        state = decoded['state']
        if cache is not None:
            state = state[:, -1:]

        return {
            'cache': decoded.get('cache'),
            'logits': embedding(state, transpose=True).transpose(2, 1),  # transpose to B x C x ...
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs '''
        return self.dropout(token_embedding(inputs) +
                            self.position_embedding(inputs))
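# --- Illustrative sketch (not part of the original classes) ---
# decode() turns the final decoder state into vocabulary logits with
# `embedding(state, transpose=True).transpose(2, 1)`.  Assuming the project's
# TokenEmbedding reuses its input embedding matrix as a tied output projection,
# that call is equivalent to the plain-torch version below; the final transpose
# puts the vocabulary dimension at index 1, the class dimension expected by
# nn.CrossEntropyLoss.
import torch


def tied_projection_sketch(state, embedding_weight):
    ''' state: (B, T, E); embedding_weight: (V, E) -> logits: (B, V, T) '''
    logits = state @ embedding_weight.t()  # (B, T, V)
    return logits.transpose(2, 1)          # (B, V, T)


# usage: shapes line up with what cross_entropy expects
_state = torch.rand(2, 5, 8)
_weight = torch.rand(11, 8)
assert tied_projection_sketch(_state, _weight).shape == (2, 11, 5)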
class NPLM(nn.Module):
    ''' The neural probabilistic LM module '''
    def __init__(self, config, dataset):
        ''' Initialize the NPLM '''
        super(NPLM, self).__init__()
        self.dataset = dataset
        self.adaptive = config.adaptive

        # ngm: number of tokens concatenated with their full embeddings
        # wsz: window size to average for the long-term context
        self.ngm, self.wsz = config.context_config
        self.long_term_block = 0 if self.ngm > 0 and self.wsz == -1 else \
            (config.batch_length - self.ngm) // self.wsz
        self.dim_concat_embs = (self.ngm * config.embedding_size +
                                self.long_term_block * config.embedding_size)

        self.embedding = TokenEmbedding(
            dataset.vocab_size,
            config.embedding_size,
            config.model_size,
            config.cutoffs,
            emb_std=config.emb_std,
            proj_std=config.proj_std,
            div_val=config.div_val,
            padding_idx=self.padding_idx,
            do_proj=config.do_proj
        )

        if self.adaptive:
            self.adaptive_softmax = AdaptiveSoftmax(self.dataset.vocab_size,
                                                    config.embedding_size,
                                                    config.embedding_size,
                                                    config.cutoffs,
                                                    div_val=config.div_val)
            self.tie_weights = config.tie_weights
            self.tie_projs = config.tie_projs
            if self.tie_weights:
                for i in range(len(self.adaptive_softmax.out_layers)):
                    self.adaptive_softmax.out_layers[i].weight = \
                        self.embedding.emb_layers[i].weight
            if self.tie_projs:
                for i in range(1, len(self.adaptive_softmax.out_projs)):
                    if config.div_val == 1 and config.model_size != config.embedding_size:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[0]
                    elif config.div_val != 1:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[i]

        self.layers = self.create_layers(config)
        self.position_embedding = PositionEmbedding(config.model_size)  # only used in transformer-N
        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)
        self.config = config

    @classmethod
    def create_layers(cls, config):
        ''' Create the NPLM decoders '''
        kwargs = {'dropout_p': config.dropout_p}  # sublayer kwargs
        args = [config, config.num_heads, config.embedding_size,
                config.hidden_dim]
        layers = nn.ModuleList([
            NPLMLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])
        return layers

    @property
    def padding_idx(self):
        return self.dataset.padding_idx

    @property
    def eos_idx(self):
        return self.dataset.eos_idx

    def reset_named_parameters(self, modules):
        if 'layers' in modules:
            for layer in self.layers:
                layer.reset_parameters()
        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch):  # pylint:disable=arguments-differ
        batch = batch.t()
        targets = left_shift(batch)
        decoded = self.decode(right_shift(batch))
        state = decoded['state']

        if not self.adaptive:
            logits = self.embedding(state, reverse=True).transpose(2, 1)
            dims = list(range(1, logits.dim()))
            nll = self.cross_entropy(logits, targets).view(-1)
            smoothed_nll = self.label_smoothing(logits, targets).sum(dims)
            if not self.config.return_rank:
                return smoothed_nll, nll
            else:
                logits = logits.transpose(2, 1)
                assert targets.shape[0] == 1
                targets = targets.squeeze(0)
                target_logits = logits[:, range(targets.shape[0]), targets]
                rank = (logits > target_logits.unsqueeze(-1)).sum(dim=-1)
                return rank, nll
        else:
            state = state.view(-1, state.shape[-1])  # (bsz*L, embed_dim)
            targets = targets.contiguous().view(-1)  # (bsz*L, )
            if not self.config.return_rank:
                nll = self.adaptive_softmax(state, targets, keep_order=True)
                smoothed_nll = nll
                return smoothed_nll, nll
            else:
                nll, rank = self.adaptive_softmax(state, targets,
                                                  keep_order=True,
                                                  return_rank=True)
                return rank, nll

    def decode(self, batch, cache=None):
        ''' Decode the batch through the NPLM layers '''
        word_embedding = self.embed(batch, self.embedding)
        decoded = {
            'cache': cache,
            'state': word_embedding,
        }

        # concat layer
        decoded = self.layers[0](decoded, layer_i=0)
        global_mem = self.layers[0].global_mem

        # regular layers
        for i, decoder in enumerate(self.layers[1:]):
            decoded = decoder(decoded, layer_i=i + 1, global_mem=global_mem)

        # compute projection to the vocabulary
        state = decoded['state']
        if cache is not None:
            state = state[:, -1:]  # fetch the newly generated token

        return {
            'cache': decoded.get('cache'),
            'state': state,  # bs x L x dim_emb or bs x L x hidden_dim
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs, no position embedding '''
        if self.config.TFN:
            return self.dropout(token_embedding(inputs) +
                                self.position_embedding(inputs))
        else:
            return self.dropout(token_embedding(inputs))
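# --- Illustrative sketch (not part of the original classes) ---
# NPLM.__init__ derives its concatenation width from config.context_config:
# the most recent `ngm` tokens keep their full embeddings, and the remaining
# (batch_length - ngm) positions are averaged in windows of size `wsz`, adding
# `long_term_block` embedding-sized slots.  Worked example with assumed values
# (not taken from any real config):
def concat_dims_sketch(batch_length, ngm, wsz, embedding_size):
    ''' Reproduce the long_term_block / dim_concat_embs arithmetic above. '''
    long_term_block = 0 if ngm > 0 and wsz == -1 else (batch_length - ngm) // wsz
    dim_concat_embs = ngm * embedding_size + long_term_block * embedding_size
    return long_term_block, dim_concat_embs


# e.g. batch_length=64, ngm=4, wsz=10, embedding_size=256:
# long_term_block = (64 - 4) // 10 = 6 and dim_concat_embs = (4 + 6) * 256 = 2560
assert concat_dims_sketch(64, 4, 10, 256) == (6, 2560)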
class ProbeNewTransformer(nn.Module):
    ''' The New Transformer module '''
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(ProbeNewTransformer, self).__init__()
        self.dataset = dataset
        self.span = config.span
        self.embedding = TokenEmbedding(
            dataset.vocab_size,
            config.embedding_size,
            padding_idx=self.padding_idx
        )
        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        # Allow for overriding the encoders and decoders in derived classes
        self.encoders = type(self).create_encoders(config)
        self.decoders = self.create_decoders(config)

        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none'
        )
        self.cross_entropy = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx,
            reduction='none'
        )

    @classmethod
    def create_encoders(cls, config):
        ''' Create the transformer encoders '''
        kwargs = {'dropout_p': config.dropout_p}
        attn_config = {'attn_type': config.attn_type,
                       'attn_position': config.attn_position,
                       'attn_param': config.attn_param,
                       'attn_displacement': config.attn_displacement,
                       'num_layers': config.num_layers,
                       'num_heads': config.num_heads,
                       'attn_concat': config.attn_concat,
                       'which_attn': 'encoder',
                       'attn_weights': config.attn_weights,
                       'attn_score': config.attn_score,
                       'attn_bins': config.attn_bins,
                       'attn_threshold': config.attn_threshold,
                       'attn_window': config.attn_window}
        args = [attn_config, config.num_heads, config.embedding_size,
                config.hidden_dim]
        return nn.ModuleList([
            TransformerEncoderLayer(*args, **kwargs)
            for _ in range(config.num_layers)
        ])

    def create_decoders(self, config):
        ''' Create the transformer decoders '''
        kwargs = {'dropout_p': config.dropout_p, 'span': config.span}
        dec_attn_config = {'attn_type': config.dec_attn_type,
                           'attn_position': config.dec_attn_position,
                           'attn_param': config.dec_attn_param,
                           'attn_displacement': config.dec_attn_displacement,
                           'num_layers': config.num_layers,
                           'num_heads': config.num_heads,
                           'attn_concat': config.dec_attn_concat,
                           'which_attn': 'decoder',
                           'attn_weights': config.dec_attn_weights,
                           'attn_score': config.dec_attn_score,
                           'attn_bins': config.dec_attn_bins,
                           'attn_threshold': config.dec_attn_threshold,
                           'attn_window': config.dec_attn_window}
        enc_dec_attn_config = {'attn_type': config.enc_dec_attn_type,
                               'attn_position': config.enc_dec_attn_position,
                               'attn_param': config.enc_dec_attn_param,
                               'attn_displacement': config.enc_dec_attn_displacement,
                               'num_layers': config.num_layers,
                               'num_heads': config.num_heads,
                               'word_count_ratio': self.dataset.word_count_ratio,
                               'attn_concat': config.enc_dec_attn_concat,
                               'which_attn': 'source',
                               'attn_weights': config.enc_dec_attn_weights,
                               'attn_score': config.enc_dec_attn_score,
                               'attn_bins': config.enc_dec_attn_bins,
                               'enc_dec_attn_layer': config.enc_dec_attn_layer,
                               'enc_dec_attn_num_heads': config.enc_dec_attn_num_heads,
                               'attn_threshold': config.enc_dec_attn_threshold,
                               'attn_window': config.enc_dec_attn_window}
        args = [dec_attn_config, enc_dec_attn_config, config.num_heads,
                config.embedding_size, config.hidden_dim]
        return nn.ModuleList([
            TransformerDecoderLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])

    @property
    def sos_idx(self):
        ''' Return the sos index '''
        return self.dataset.sos_idx

    @property
    def padding_idx(self):
        ''' Return the padding index '''
        return self.dataset.padding_idx

    def translator(self, config):
        ''' Get a translator for this model '''
        return ProbeNewTranslator(config, self, self.dataset)

    def reset_named_parameters(self, modules):
        ''' Reset the parameters of the named modules '''
        if 'encoder' in modules:
            for encoder in self.encoders:
                encoder.reset_parameters()
        if 'decoder' in modules:
            for decoder in self.decoders:
                decoder.reset_parameters()
        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch):  # pylint:disable=arguments-differ
        ''' A batch of inputs and targets '''
        encoded, encoder_attn_weights_tensor = self.encode(batch['inputs'])
        decoded = self.decode(
            encoded,
            right_shift(right_shift(batch['targets']),
                        shift=self.span - 1,
                        fill=self.sos_idx),
            input_lens=batch['input_lens']
        )

        logits = decoded['logits']
        dims = list(range(1, logits.dim()))
        targets = left_shift(batch['targets'])
        nll = self.cross_entropy(logits, targets).sum(dims[:-1])
        smoothed_nll = self.label_smoothing(logits, targets).sum(dims)
        return {'smoothed_nll': smoothed_nll,
                'nll': nll,
                'encoder_attn_weights_tensor': encoder_attn_weights_tensor,
                'decoder_attn_weights_tensor': decoded['decoder_attn_weights_tensor'],
                'enc_dec_attn_weights_tensor': decoded['enc_dec_attn_weights_tensor']}

    def encode(self, inputs):
        ''' Encode the inputs '''
        word_embedding = self.embed(inputs, self.embedding)
        encoded = {
            'state': word_embedding,
            'mask': inputs.eq(self.padding_idx)
        }
        encoder_attn_weights_list = []
        for i, encoder in enumerate(self.encoders):
            encoded = encoder(encoded, i, word_embedding)
            encoder_attn_weights_list.append(encoded['encoder_attn_weights'])
        encoder_attn_weights_tensor = torch.stack(encoder_attn_weights_list)
        return encoded, encoder_attn_weights_tensor

    def decode(self, encoded, targets, decoders=None, embedding=None,
               cache=None, mask=None, input_lens=None):
        ''' Decode the encoded sequence to the targets '''
        if decoders is None:
            decoders = self.decoders
        if embedding is None:
            embedding = self.embedding

        word_embedding = self.embed(targets, embedding)
        decoded = {
            'cache': cache,
            'state': word_embedding,
            'mask': targets.eq(self.padding_idx) if mask is None else mask,
            'input_lens': input_lens
        }
        decoder_attn_weights_list = []
        enc_dec_attn_weights_list = []
        for i, decoder in enumerate(decoders):
            decoded = decoder(decoded, encoded, i, word_embedding)
            decoder_attn_weights_list.append(decoded['decoder_attn_weights'])
            if 'enc_dec_attn_weights' in decoded:
                enc_dec_attn_weights_list.append(decoded['enc_dec_attn_weights'])
        decoder_attn_weights_tensor = torch.stack(decoder_attn_weights_list)
        enc_dec_attn_weights_tensor = torch.stack(enc_dec_attn_weights_list)

        # compute projection to the vocabulary
        state = decoded['state']
        if cache is not None:
            state = state[:, -self.span:]

        return {
            'cache': decoded.get('cache'),
            'logits': embedding(state, transpose=True).transpose(2, 1),  # transpose to B x C x ...
            'decoder_attn_weights_tensor': decoder_attn_weights_tensor,
            'enc_dec_attn_weights_tensor': enc_dec_attn_weights_tensor
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs '''
        return self.dropout(token_embedding(inputs) +
                            self.position_embedding(inputs))
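# --- Illustrative sketch (not part of the original classes) ---
# ProbeNewTransformer collects one attention-weight tensor per layer and stacks
# them with torch.stack, so each probe output gains a leading layer dimension.
# The per-layer shape used below, (batch, heads, query_len, key_len), is an
# assumption about what the encoder/decoder layers return; only the stacking
# behaviour is taken from the code above.
import torch


def stack_attention_probe_sketch(num_layers=6, batch=2, heads=8, q_len=5, k_len=7):
    ''' Stack per-layer attention weights into one probe tensor. '''
    per_layer = [torch.rand(batch, heads, q_len, k_len) for _ in range(num_layers)]
    stacked = torch.stack(per_layer)  # (num_layers, batch, heads, q_len, k_len)
    assert stacked.shape == (num_layers, batch, heads, q_len, k_len)
    return stacked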
class Transformer(nn.Module):
    ''' The Transformer LM module '''
    def __init__(self, config, dataset):
        ''' Initialize the Transformer '''
        super(Transformer, self).__init__()
        self.dataset = dataset
        self.config = config
        self.adaptive = config.adaptive
        self.embedding = TokenEmbedding(dataset.vocab_size,
                                        config.embedding_size,
                                        config.model_size,
                                        config.cutoffs,
                                        emb_std=config.emb_std,
                                        proj_std=config.proj_std,
                                        div_val=config.div_val,
                                        padding_idx=self.padding_idx,
                                        do_proj=config.do_proj)

        if self.adaptive:
            self.adaptive_softmax = AdaptiveSoftmax(self.dataset.vocab_size,
                                                    config.embedding_size,
                                                    config.model_size,
                                                    config.cutoffs,
                                                    div_val=config.div_val)
            self.tie_weights = config.tie_weights
            self.tie_projs = config.tie_projs
            if self.tie_weights:
                for i in range(len(self.adaptive_softmax.out_layers)):
                    self.adaptive_softmax.out_layers[i].weight = \
                        self.embedding.emb_layers[i].weight
            if self.tie_projs:
                for i in range(1, len(self.adaptive_softmax.out_projs)):
                    if config.div_val == 1 and config.model_size != config.embedding_size:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[0]
                    elif config.div_val != 1:
                        self.adaptive_softmax.out_projs[i] = self.embedding.emb_projs[i]

        self.position_embedding = PositionEmbedding(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_p, inplace=True)

        if len(config.no_attention) == 1:
            config.no_attention = config.no_attention * config.num_layers
        assert len(config.no_attention) == config.num_layers

        self.layers = self.create_layers(config)
        self.label_smoothing = LabelSmoothingLoss(
            config.label_smoothing or 0,
            ignore_index=self.padding_idx,
            reduction='none')
        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=self.padding_idx,
                                                 reduction='none')

    @classmethod
    def create_layers(cls, config):
        ''' Create the transformer decoders '''
        kwargs = {'dropout_p': config.dropout_p}  # sublayer kwargs
        args = [config, config.num_heads, config.model_size,
                config.hidden_dim]
        layers = nn.ModuleList([
            TransformerLayer(*args, layer_i, **kwargs)
            for layer_i in range(config.num_layers)
        ])
        return layers

    @property
    def padding_idx(self):
        return self.dataset.padding_idx

    @property
    def eos_idx(self):
        return self.dataset.eos_idx

    def reset_named_parameters(self, modules):
        if 'layers' in modules:
            for layer in self.layers:
                layer.reset_parameters()
        if 'embeddings' in modules:
            self.embedding.reset_parameters()

    def forward(self, batch, global_mask=None):  # pylint:disable=arguments-differ
        ''' batch: length x bsz '''
        batch = batch.transpose(1, 0)
        targets = left_shift(batch)
        decoded = self.decode(right_shift(batch), global_mask=global_mask)
        state = decoded['state']

        if not self.adaptive:
            logits = self.embedding(state, reverse=True).transpose(2, 1)
            dims = list(range(1, logits.dim()))
            nll = self.cross_entropy(logits, targets).view(-1)
            smoothed_nll = self.label_smoothing(logits, targets).sum(dims)
            if not self.config.return_rank:
                return smoothed_nll, nll
            else:
                logits = logits.transpose(2, 1)
                assert targets.shape[0] == 1
                targets = targets.squeeze(0)
                target_logits = logits[:, range(targets.shape[0]), targets]
                rank = (logits > target_logits.unsqueeze(-1)).sum(dim=-1)
                return rank, nll
        else:
            if self.config.batch_length < state.size(1):
                state = state[:, -self.config.batch_length:].contiguous()
                targets = targets[:, -self.config.batch_length:].contiguous()
            state = state.view(-1, state.shape[-1])  # (bsz*L, embed_dim)
            targets = targets.contiguous().view(-1)  # (bsz*L, )
            if not self.config.return_rank:
                nll = self.adaptive_softmax(state, targets, keep_order=True)
                smoothed_nll = nll
                return smoothed_nll, nll
            else:
                nll, rank = self.adaptive_softmax(state, targets,
                                                  keep_order=True,
                                                  return_rank=True)
                return rank, nll

    def decode(self, batch, cache=None, global_mask=None):
        ''' Decode the batch through the transformer layers '''
        word_embedding = self.embed(batch, self.embedding)
        decoded = {
            'state': word_embedding,
        }
        # zero out the embeddings of padding positions
        decoded['state'][batch == self.padding_idx] = 0
        for i, decoder in enumerate(self.layers):
            decoded = decoder(decoded, layer_i=i, global_mask=global_mask)

        return {
            'state': decoded['state'],  # bs x L x hidden_dim
        }

    def embed(self, inputs, token_embedding):
        ''' Embed the given inputs '''
        return self.dropout(
            token_embedding(inputs) + self.position_embedding(inputs))
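# --- Illustrative sketch (not part of the original classes) ---
# When config.return_rank is set, forward() reports, for each position, how many
# vocabulary entries score strictly higher than the target token (rank 0 means
# the target is the model's top prediction).  Minimal reproduction of that
# computation on made-up logits:
import torch

_logits = torch.tensor([[[0.1, 2.0, 0.5],     # position 0: target id 2, one higher score
                         [3.0, 0.2, 0.4]]])   # position 1: target id 0, top prediction
_targets = torch.tensor([2, 0])
_target_logits = _logits[:, range(_targets.shape[0]), _targets]   # (1, 2)
_rank = (_logits > _target_logits.unsqueeze(-1)).sum(dim=-1)      # (1, 2)
assert _rank.tolist() == [[1, 0]]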