class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, embed_dim, self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args)
            for i in range(args.encoder_layers)
        ])

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.layer_norm:
            x = self.layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out['encoder_out'] is not None:
            encoder_out['encoder_out'] = \
                encoder_out['encoder_out'].index_select(1, new_order)
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = \
                encoder_out['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions())

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
        for i in range(len(self.layers)):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")
        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
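# Editorial sketch (not part of the original source): forward() above builds a
# boolean padding mask from the token ids and moves the batch to the second
# dimension before running the layer stack. A minimal stand-alone illustration,
# assuming only torch; the toy sizes and pad index are arbitrary.
import torch

def _sketch_padding_mask_and_transpose(padding_idx=1):
    src_tokens = torch.tensor([[5, 6, 7, 1, 1],
                               [8, 9, 10, 11, 1]])       # B x T, 1 = <pad>
    x = torch.randn(2, 5, 4)                             # B x T x C embeddings
    encoder_padding_mask = src_tokens.eq(padding_idx)    # B x T, True at padded positions
    if not encoder_padding_mask.any():
        encoder_padding_mask = None                      # skip masking when nothing is padded
    x = x.transpose(0, 1)                                # T x B x C, as the layers expect
    return x.shape, encoder_padding_mask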
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.agg_method = args.agg_method self.agg_layers = args.agg_layers self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(args) for i in range(args.encoder_layers) ]) self.attn = MultiheadAttention(embed_dim, args.encoder_attention_heads, dropout=args.attention_dropout, encoder_decoder_attention=True) self.fc = Linear(args.agg_layers * embed_dim, embed_dim) self.activation_fn = utils.get_activation_fn( activation=getattr(args, 'activation_fn', 'relu')) self.activation_dropout = getattr(args, 'activation_dropout', 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, 'relu_dropout', 0) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None def forward(self, src_tokens, src_lengths): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if not encoder_padding_mask.any(): encoder_padding_mask = None # encoder layers prev_x = [] for idx, layer in enumerate(self.layers): x = layer(x, encoder_padding_mask) # if idx != len(self.layers)-1: prev_x.append(x) # history = torch.cat(prev_x, 2) # history = self.activation_fn(self.fc(history)) # history = F.dropout(history, p=self.activation_dropout, training=self.training) # x, _ = self.attn(query=history, key=x, value=x, key_padding_mask=encoder_padding_mask) prev_x = prev_x[-1 - self.agg_layers:-1] if self.agg_method == 'add': for his in prev_x: x += his elif self.agg_method == 'fc': his = self.activation_fn(self.fc(torch.cat(prev_x, 2))) his = F.dropout(his, p=self.activation_dropout, training=self.training) x, _ = self.attn(query=his, key=x, value=x, key_padding_mask=encoder_padding_mask) elif self.agg_method == 'attn': his = prev_x[0] for idx in range(1, len(prev_x)): his, _ = self.attn(query=his, key=prev_x[idx], value=prev_x[idx], key_padding_mask=encoder_padding_mask) his = F.dropout(his, p=self.dropout, training=self.training) x, _ = 
self.attn(query=his, key=x, value=x, key_padding_mask=encoder_padding_mask) x = F.dropout(x, p=self.dropout, training=self.training) if self.layer_norm: x = self.layer_norm(x) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def reorder_encoder_input(self, encoder_input, new_order): """ Reorder encoder input according to *new_order*. Args: encoder_input: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_input* rearranged according to *new_order* """ if encoder_input['src_tokens'] is not None: encoder_input['src_tokens'] = \ encoder_input['src_tokens'].index_select(0, new_order) return encoder_input def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
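# Editorial sketch: the aggregation branch above keeps the last agg_layers
# layer outputs and either sums them into the top output ('add') or
# concatenates them and projects back to the model dimension ('fc') before the
# final attention. A toy version of those two combinations, assuming only
# torch; the linear layer here is a stand-in for self.fc, not the module above.
import torch
import torch.nn as nn

def _sketch_layer_aggregation(agg_method='fc', agg_layers=2, embed_dim=8):
    prev_x = [torch.randn(5, 2, embed_dim) for _ in range(4)]  # per-layer outputs, T x B x C
    x = prev_x[-1]
    history = prev_x[-1 - agg_layers:-1]                       # the layers to aggregate
    if agg_method == 'add':
        for his in history:
            x = x + his
    elif agg_method == 'fc':
        fc = nn.Linear(agg_layers * embed_dim, embed_dim)      # stand-in for self.fc
        x = torch.relu(fc(torch.cat(history, dim=2)))
    return x.shape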
class TransformerEncoderAug(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens,lm): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.src_lm = lm self.acl_drop = args.sca_drop self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(args) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.ln1 = Linear(2 *embed_dim, embed_dim) self.ln2 = Linear(2*len(dictionary), len(dictionary)) self.ln3 = Linear(embed_dim, embed_dim,bias = False) def forward(self, src_tokens, src_tokens_lm, src_lengths ): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ ## compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) mask = encoder_padding_mask.eq(0).type(torch.FloatTensor).unsqueeze(dim = -1).cuda() if not encoder_padding_mask.any(): mask = None encoder_padding_mask = None src_lm = src_tokens_lm #with torch.no_grad(): if self.training: prop = self.acl_drop word_drop = (torch.rand(src_tokens.size()) > prop).type(torch.LongTensor) lm_drop = word_drop.eq(0).type(torch.FloatTensor) lm_drop = lm_drop.unsqueeze(dim = -1) lm_ = src_lm * lm_drop.cuda() x_ = src_tokens * word_drop.cuda() x1 = self.embed_tokens(x_) x2 = F.linear(lm_,self.embed_tokens.weight.t()) x = self.embed_scale *(x1+x2) else: x = self.embed_scale * self.embed_tokens(src_tokens) x3 = F.linear(src_tokens_lm, self.src_lm.embed_tokens.weight.t()) x3 = F.dropout(x3, p = 0.1, training =self.training) x = self.ln1(torch.cat((x,x3),dim = -1)) #x = x + x4 if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if self.layer_norm: x = self.layer_norm(x) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named(state_dict, f"{name}.layers.{i}") version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
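# Editorial sketch: during training, TransformerEncoderAug randomly drops
# source tokens and, at the dropped positions, substitutes the language-model
# distribution projected through the (tied) embedding matrix. A self-contained
# toy version of that mixing, assuming only torch; sizes are arbitrary and
# drop_prob stands in for args.sca_drop.
import torch
import torch.nn.functional as F

def _sketch_lm_word_dropout(drop_prob=0.2, vocab=20, dim=8):
    embed = torch.nn.Embedding(vocab, dim, padding_idx=0)
    src_tokens = torch.randint(1, vocab, (2, 5))               # B x T
    src_lm = torch.softmax(torch.randn(2, 5, vocab), dim=-1)   # per-token LM distribution
    keep = (torch.rand(src_tokens.size()) > drop_prob).long()  # 1 = keep the token
    lm_mask = keep.eq(0).float().unsqueeze(-1)                 # 1 at dropped positions
    x_tok = embed(src_tokens * keep)                           # dropped tokens fall back to index 0
    x_lm = F.linear(src_lm * lm_mask, embed.weight.t())        # expected embedding under the LM
    return x_tok + x_lm                                        # B x T x C mixture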
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layer_wise_attention = getattr(args, 'layer_wise_attention', False) self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(args, layer_id=i) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None if getattr(args, 'layernorm_embedding', False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None def forward_embedding(self, src_tokens): # embed tokens and positions embed = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding: x = self.layernorm_embedding(x) x = F.dropout(x, p=self.dropout, training=self.training) return x, embed def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False, **unused): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_embedding** (Tensor): the (scaled) embedding lookup of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. 
""" if self.layer_wise_attention: return_all_hiddens = True x, encoder_embedding = self.forward_embedding(src_tokens) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if not encoder_padding_mask.any(): encoder_padding_mask = None encoder_states = [] if return_all_hiddens else None # encoder layers for layer in self.layers: # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) if not self.training or (dropout_probability > self.encoder_layerdrop): x = layer(x, encoder_padding_mask) if return_all_hiddens: encoder_states.append(x) if self.layer_norm: x = self.layer_norm(x) if return_all_hiddens: encoder_states[-1] = x return EncoderOut( encoder_out=x, # T x B x C encoder_padding_mask=encoder_padding_mask, # B x T encoder_embedding=encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] ) def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out.encoder_out is not None: encoder_out = encoder_out._replace( encoder_out=encoder_out.encoder_out.index_select(1, new_order)) if encoder_out.encoder_padding_mask is not None: encoder_out = encoder_out._replace( encoder_padding_mask=encoder_out.encoder_padding_mask. index_select(0, new_order)) if encoder_out.encoder_embedding is not None: encoder_out = encoder_out._replace( encoder_embedding=encoder_out.encoder_embedding.index_select( 0, new_order)) if encoder_out.encoder_states is not None: for idx, state in enumerate(encoder_out.encoder_states): encoder_out.encoder_states[idx] = state.index_select( 1, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): dim = tensor.size(0) if not hasattr( self, '_future_mask' ) or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(self._future_mask.resize_( dim, dim)), 1) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: print('deleting {0}'.format(weights_key)) del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
class TransformerAvgEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(args) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None #image section self.img_dim = 2048 self.text_dim = embed_dim self.L2norm = args.L2norm self.total_num_img = args.total_num_img self.per_num_img = args.per_num_img # cap2image_file = args.cap2image_file # image_embedding_file = args.image_embedding_file cap2image_file = getattr(args, "cap2image_file", "data/cap2image.pickle") image_embedding_file = getattr(args, "image_embedding_file", "features_resnet50/train-resnet50-avgpool.npy") self.cap2image = pickle.load(open(cap2image_file, "rb")) #cap_id to image_id #print("image embedding processing...") embeding_weights = np.load(image_embedding_file) img_vocab, img_dim = embeding_weights.shape embeddings_matrix = np.zeros((img_vocab + 1, img_dim)) embeddings_matrix[1:] = embeding_weights self.img_embeddings = nn.Embedding.from_pretrained(torch.FloatTensor(embeddings_matrix), freeze=args.image_emb_fix) # update embedding # self.img_embeddings.load_state_dict({'weight': embeddings_matrix}) # if args.image_emb_fix: # self.img_embeddings.weight.requires_grad = False self.merge_option = args.merge_option self.dense = nn.Linear(self.img_dim, self.text_dim) self.mergeImage = nn.Linear(self.total_num_img, 1) if self.merge_option == "att-mul-concat": self.proj_attention = SCAttention(self.text_dim, 128) self.dense2 = nn.Linear(self.text_dim, 384) elif self.merge_option == "att-concat": self.dense2 = nn.Linear(2 * self.text_dim, self.text_dim) elif self.merge_option == "att-gate": self.gate_type = args.gate_type self.proj_attention = SCAttention(self.text_dim, self.text_dim) if self.gate_type == "neural-gate": self.sigmoid = nn.Sigmoid() self.gate_dense = nn.Linear(2*self.text_dim, self.text_dim) elif self.gate_type == "scalar-gate": self.sigmoid = nn.Sigmoid() self.gate_dense = nn.Linear(2*self.text_dim, 1) else: self.image_weight = args.image_weight else: self.proj_attention = SCAttention(self.text_dim, self.text_dim) def forward(self, src_tokens, src_lengths): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ srl_tok_list = src_tokens.tolist() batch_image_ids = [] for batch_idx, sent in enumerate(srl_tok_list): 
# token2image image_ids = [] for cap in sent: if cap in self.cap2image: for id in self.cap2image[cap][:self.per_num_img]: if id != 0: image_ids.append(id) image_freq= Counter(image_ids) image_sort = sorted(image_freq.items(), key=lambda x: x[1], reverse=True) image_ids = [item[0] for idx, item in enumerate(image_sort) if idx < self.total_num_img] #image_ids = image_ids[:self.total_num_img] # Zero-pad up to the sequence length. padding_length = self.total_num_img - len(image_ids) image_ids = image_ids + ([0] * padding_length) batch_image_ids.append(image_ids) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") batch_image_ids = torch.LongTensor(batch_image_ids).to(device) #image embedding batch_size, num_img = batch_image_ids.size() #print(batch_image_ids[0]) image_padding_mask = batch_image_ids.eq(0) #print(image_padding_mask[0]) image_mask = ~image_padding_mask # print(image_mask[0]) # embed tokens and positions x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask # print(src_tokens) encoder_padding_mask = src_tokens.eq(self.padding_idx) text_mask = ~encoder_padding_mask #print(encoder_padding_mask.size()) if not encoder_padding_mask.any(): encoder_padding_mask = None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if self.layer_norm: x = self.layer_norm(x) # image_ids_flat = image_ids.view(batch_size, -1) image_ids_flat = batch_image_ids #batch_size x num_image image_embedding = self.img_embeddings(image_ids_flat) image_embedding = image_embedding.view(batch_size, num_img, self.img_dim) # batch_size, num_img, dim #L2 norm if self.L2norm == "true": image_embedding = F.normalize(image_embedding, p=2, dim=1) # attention on each local region text_repr = x.transpose(0, 1) # T x B x C -> batch_size, seq_len, dim image_repr = self.dense(image_embedding) # batch_size, num_img, image_dim - > text_dim if self.merge_option == "biatt": output = self.proj_attention(image_repr, image_mask, text_repr, text_mask) output = self.proj_attention(text_repr, text_mask, output, image_mask) #batch_size, seq_len, dim elif self.merge_option == "att": output = self.proj_attention(text_repr, text_mask, image_repr, image_mask) #batch_size, seq_len, dim elif self.merge_option == "att-sum": output = self.proj_attention(text_repr, text_mask, image_repr, image_mask) # batch_size, seq_len, dim output = text_repr + output #0.5, 1 elif self.merge_option == "att-gate": output = self.proj_attention(text_repr, text_mask, image_repr, image_mask) # batch_size, seq_len, dim if self.gate_type == "neural-gate": merge = torch.cat([text_repr, output], dim=-1) gate = self.sigmoid(self.gate_dense(merge)) output = (1 - gate)*text_repr + gate*output #print("neural-gate") elif self.gate_type == "scalar-gate": merge = torch.cat([text_repr, output], dim=-1) gate = self.sigmoid(self.gate_dense(merge)) output = (1 - gate)*text_repr + gate*output #print("scalar-gate") else: output = 1.0*text_repr + self.image_weight*output else: output = None #print(output.size()) x = output.transpose(0, 1) # batch_size, seq_len, dim -> T x B x C #print(x) # print(x.size()) # print(encoder_padding_mask.size()) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i)) version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
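# Editorial sketch: the "att-gate" merge in TransformerAvgEncoder combines the
# text states with the image-attended states through a learned sigmoid gate. A
# stand-alone toy version, assuming only torch; gate_dense is illustrative and
# stands in for self.gate_dense.
import torch
import torch.nn as nn

def _sketch_neural_gate(dim=8):
    text_repr = torch.randn(2, 5, dim)                # B x T x C text states
    image_ctx = torch.randn(2, 5, dim)                # B x T x C image-attended states
    gate_dense = nn.Linear(2 * dim, dim)              # stand-in for self.gate_dense
    gate = torch.sigmoid(gate_dense(torch.cat([text_repr, image_ctx], dim=-1)))
    return (1 - gate) * text_repr + gate * image_ctx  # gated mixture, B x T x C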
class JointAttentionEncoder(FairseqEncoder): """ JointAttention encoder is used only to compute the source embeddings. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding left_pad (bool): whether the input is left-padded """ def __init__(self, args, dictionary, embed_tokens, left_pad): super().__init__(dictionary) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.embed_language = LanguageEmbedding( embed_dim) if args.language_embeddings else None self.register_buffer('version', torch.Tensor([2])) def forward(self, src_tokens, src_lengths): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): embedding output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) # language embedding if self.embed_language is not None: lang_emb = self.embed_scale * self.embed_language.view(1, 1, -1) x += lang_emb x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if not encoder_padding_mask.any(): encoder_padding_mask = None return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions())
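# Editorial sketch: JointAttentionEncoder only embeds the source; when language
# embeddings are enabled it adds a single learned vector, scaled and broadcast
# over every position, on top of the token and positional embeddings. Toy
# illustration, assuming only torch; embed_language is a plain parameter here.
import math
import torch

def _sketch_language_embedding(dim=8):
    x = torch.randn(2, 5, dim)                                  # B x T x C token + position embeddings
    embed_language = torch.nn.Parameter(torch.zeros(dim))
    lang_emb = math.sqrt(dim) * embed_language.view(1, 1, -1)   # scaled, broadcast to B x T x C
    return x + lang_emb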
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens, max_source_positions, encoder_layers, encoder_embed_dim, encoder_attention_heads, encoder_ffn_embed_dim,): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout if embed_tokens is None: self.padding_idx = 0 embed_dim = encoder_embed_dim else: self.padding_idx = embed_tokens.padding_idx embed_dim = embed_tokens.embedding_dim self.max_source_positions = max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layer_wise_attention = getattr(args, 'layer_wise_attention', False) self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(encoder_embed_dim, encoder_attention_heads, encoder_ffn_embed_dim, args) for i in range(encoder_layers) ]) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.use_seg_pos_emb = getattr(args, 'use_seg_pos_emb', 0) if self.use_seg_pos_emb: self.seg_pad_idx = 2 self.seg_pos_emb = Embedding(3, embed_dim, padding_idx=self.seg_pad_idx) def forward_embedding(self, src_tokens, seg_pos=-1): # embed tokens and positions embed = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.use_seg_pos_emb: masked_src_tokens = src_tokens.masked_fill(src_tokens.ne(self.padding_idx), seg_pos) masked_src_tokens = masked_src_tokens.masked_fill(src_tokens.eq(self.padding_idx), self.seg_pad_idx) x = x + self.embed_scale * self.seg_pos_emb(masked_src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) return x, embed def forward(self, src_tokens=None, cls_input=None, return_all_hiddens=False, src_encodings=None, encoder_padding_mask=None, attn_mask=None, auxilary_tokens=None): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` src_encodings (torch.FloatTensor): shape of `(T x B x C)` encoder_padding_mask (torch.Boolean): shape of '(B x T)', where paddings are True return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. 
""" if self.layer_wise_attention: return_all_hiddens = True if self.embed_tokens is not None: x, encoder_embedding = self.forward_embedding(src_tokens, seg_pos=0) aug_x, _ = self.forward_embedding(auxilary_tokens, seg_pos=1) x = torch.cat([x, aug_x], dim=1) # B x T x C -> T x B x C x = x.transpose(0, 1) src_tokens = torch.cat([src_tokens, auxilary_tokens], dim=1) # compute padding mask encoder_padding_mask = (src_tokens.eq(self.padding_idx) | src_tokens.eq(self.dictionary.bos_index)) else: assert encoder_padding_mask is not None src_tokens = encoder_padding_mask.long() encoder_embedding = None x = src_encodings if self.embed_positions is not None: x = x + self.embed_positions(src_tokens).transpose(0, 1) x = F.dropout(x, p=self.dropout, training=self.training) encoder_states = [] if return_all_hiddens else None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask, attn_mask=attn_mask) if return_all_hiddens: encoder_states.append(x) if self.layer_norm: x = self.layer_norm(x) if return_all_hiddens: encoder_states[-1] = x return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T 'encoder_embedding': encoder_embedding, # B x T x C 'encoder_states': encoder_states, # List[T x B x C] } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) if encoder_out.get('encoder_states', None) is not None: for idx, state in enumerate(encoder_out['encoder_states']): encoder_out['encoder_states'][idx] = state.index_select(1, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): dim = tensor.size(0) if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i)) version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
class TaLKConvEncoder(FairseqEncoder): """ Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TaLKConvEncoderLayer(args, kernel_size=args.encoder_kernel_size_list[i]) for i in range(args.encoder_layers) ]) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.encoder_normalize_before if self.normalize: self.layer_norm = LayerNorm(embed_dim) self.acts_reg = [] def forward(self, src_tokens, **unused): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if not encoder_padding_mask.any(): encoder_padding_mask = None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if self.normalize: x = self.layer_norm(x) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions())
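# Editorial sketch: reorder_encoder_out() in these encoders reorders the cached
# encoder states when beam search reshuffles hypotheses; encoder_out is
# T x B x C, so its batch dimension is dim 1, while the padding mask is B x T
# (dim 0). Minimal illustration, assuming only torch.
import torch

def _sketch_reorder_encoder_out():
    encoder_out = torch.randn(5, 3, 8)                  # T x B x C
    encoder_padding_mask = torch.zeros(3, 5, dtype=torch.bool)
    new_order = torch.tensor([2, 0, 1])
    return (encoder_out.index_select(1, new_order),
            encoder_padding_mask.index_select(0, new_order))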
class LightConvEncoder(FairseqEncoder): """ LightConv encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`LightConvEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.encoder_embed_dim = args.encoder_embed_dim #self.bi_rnn_layer = torch.nn.GRU( # args.encoder_embed_dim, # args.encoder_embed_dim, # num_layers=1, # batch_first=True, # bidirectional=True #) self.rnn_layer = torch.nn.GRU(args.encoder_embed_dim, args.encoder_embed_dim, num_layers=4, dropout=0.1, batch_first=True, bidirectional=False) self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ LightConvEncoderLayer(args, kernel_size=args.encoder_kernel_size_list[i]) for i in range(args.encoder_layers) ]) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.encoder_normalize_before if self.normalize: self.layer_norm = LayerNorm(embed_dim) def forward(self, src_tokens, **unused): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # We are in a character-level settings, let's add some char-rnn # embed tokens and positions src_tokens_temp = self.embed_tokens(src_tokens) #bi_output, _ = self.bi_rnn_layer(src_tokens_temp) # Concatenate (PLUS) the bidirectional encoder rnn outputs so that we have the original embedding size #src_tokens_temp = bi_output[:, :, :self.encoder_embed_dim] + bi_output[:, :, self.encoder_embed_dim:] # Average values #src_tokens_temp = src_tokens_temp / 2.0 src_tokens_temp, _ = self.rnn_layer(src_tokens_temp) x = self.embed_scale * src_tokens_temp if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if not encoder_padding_mask.any(): encoder_padding_mask = None # encoder layers for layer in self.layers: # TODO(naetherm):residual_x = x x = layer(x, encoder_padding_mask) # TODO(naetherm):x = residual_x + x if self.normalize: x = self.layer_norm(x) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions())
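# Editorial sketch: this LightConv variant runs the (character-level) token
# embeddings through a GRU before scaling and adding positional embeddings. A
# self-contained toy version of that preprocessing, assuming only torch; the
# single-layer GRU and the size `dim` stand in for the 4-layer GRU and
# args.encoder_embed_dim used above.
import math
import torch

def _sketch_char_rnn_preprocessing(dim=8):
    embed = torch.nn.Embedding(20, dim, padding_idx=1)
    rnn = torch.nn.GRU(dim, dim, num_layers=1, batch_first=True)
    src_tokens = torch.randint(2, 20, (2, 6))       # B x T
    h, _ = rnn(embed(src_tokens))                   # B x T x C recurrent features
    return math.sqrt(dim) * h                       # scaled like the embeddings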
class TransformerEncoderC(nn.Module): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Controller stream Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, vocab, embed_tokens): super().__init__() self.vocab = vocab self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayerC(args) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.g_layer_norm = LayerNorm(embed_dim) else: self.g_layer_norm = None def forward(self, src_tokens, encoder_mode="gumbel", encoder_temperature=-1, need_weights=False, **unused): # embed tokens and positions embedding = self.embed_tokens(src_tokens) g = self.embed_scale * embedding g += self.embed_positions(src_tokens) g = F.dropout(g, p=self.dropout, training=self.training) # B x T x C -> T x B x C g = g.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if not encoder_padding_mask.any(): encoder_padding_mask = None # encoder layers attn_list = [] attn_data_list = [] for layer in self.layers: g, attn, attn_data = layer( g, encoder_padding_mask=encoder_padding_mask, encoder_mode=encoder_mode, encoder_temperature=encoder_temperature, need_weights=need_weights) attn_list.append(attn) attn_data_list.append(attn_data) if self.g_layer_norm: g = self.g_layer_norm(g) return { 'encoder_g': g, # T x B x C 'encoder_attn': attn_list, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T 'encoder_attn_data_list': attn_data_list } def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions())
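# Editorial sketch: TransformerEncoderLayerC in the controller stream above is
# driven by an encoder_mode / encoder_temperature pair ("gumbel" by default).
# Its internals are not shown here; as a generic illustration of what a
# temperature-controlled Gumbel mode usually means, this draws a soft one-hot
# sample from a set of logits. Assumes only torch and is not the layer's
# actual implementation.
import torch
import torch.nn.functional as F

def _sketch_gumbel_selection(temperature=1.0):
    logits = torch.randn(2, 5, 4)                 # e.g. per-position choice logits
    return F.gumbel_softmax(logits, tau=temperature, hard=False)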
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(args) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.add_template = args.add_template if self.add_template: self.template_layers = nn.ModuleList([]) self.template_layers.extend([ TransformerEncoderLayer(args) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.tp_layer_norm = LayerNorm(embed_dim) else: self.tp_layer_norm = None self.positionwise = PositionWise(embed_dim, embed_dim, self.dropout) self.two_encoder_mix = nn.Linear(2 * embed_dim, embed_dim) self.attention = MultiheadAttention(embed_dim, args.encoder_attention_heads, dropout=args.attention_dropout) def forward(self, src_tokens, src_lengths, template=None): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if self.layer_norm: x = self.layer_norm(x) if self.add_template: tp = self.embed_scale * self.embed_tokens(template) if self.embed_positions is not None: tp += self.embed_positions(template) tp = F.dropout(tp, p=self.dropout, training=self.training) # B x T x C -> T x B x C tp = tp.transpose(0, 1) # compute padding mask tp_encoder_padding_mask = template.eq(self.padding_idx) # encoder layers for layer in self.template_layers: tp = layer(tp, tp_encoder_padding_mask) if self.tp_layer_norm: tp = self.tp_layer_norm(tp) adj_att, _ = self.attention( query=x, key=tp, value=tp, key_padding_mask=tp_encoder_padding_mask) adj_att = F.dropout(adj_att, p=self.dropout, training=self.training) adj_egd_cat = torch.cat([adj_att, x], dim=-1) two_encoder = self.two_encoder_mix(adj_egd_cat) gate = torch.sigmoid(two_encoder) output = gate.mul(adj_att) + (1 - gate).mul(x) x = self.positionwise(output) if encoder_padding_mask is not None: mean_mask = 1 - encoder_padding_mask mean_mask = mean_mask.unsqueeze(2).repeat(1, 1, 
x.size()[2]).transpose( 0, 1).float() adj_att_mean = mean_mask * adj_att adj_att_mean = torch.mean(adj_att_mean, dim=0) else: adj_att_mean = torch.mean(adj_att, dim=0) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T 'tp_mean': adj_att_mean, } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
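# Editorial sketch: when a template is given, the encoder above also returns a
# masked mean of the template-attended states ('tp_mean'). A toy version of
# that masked average over the time dimension, assuming only torch; adj_att is
# T x B x C and ~mask is the boolean equivalent of the `1 - mask` used above.
import torch

def _sketch_masked_time_mean(dim=8):
    adj_att = torch.randn(5, 2, dim)                           # T x B x C
    encoder_padding_mask = torch.zeros(2, 5, dtype=torch.bool)
    encoder_padding_mask[0, 3:] = True                         # pad the tail of sample 0
    mean_mask = (~encoder_padding_mask).unsqueeze(2).expand(-1, -1, dim).transpose(0, 1).float()
    return (mean_mask * adj_att).mean(dim=0)                   # B x C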
class SimMTTransformerMultiPassEncoder(FairseqEncoder): """ SimMTTransformerMultiPass encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayerOurs`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayerOurs(args) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.wait_k = args.wait_k def forward(self, src_tokens, src_lengths): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) # unfold, forward T' pass # pdb.set_trace() t = x.shape[0] t_fw = max(t - self.wait_k + 1, 1) # T': t forward # padding mask encoder_padding_mask = encoder_padding_mask.unsqueeze(1).repeat(1, t_fw, 1) # B x T => B x T' x T # mask time time_mask = (torch.arange(t)[None, :] > torch.arange(t)[:, None]).to(x.device)[min(self.wait_k, t) - 1:, :] # T' x T encoder_padding_mask = (encoder_padding_mask | time_mask[None, :, :]) # B x T' x T encoder_padding_mask = encoder_padding_mask.view(-1, t) # (B * T') x T # feature x = x.unsqueeze(2).repeat(1, 1, t_fw, 1) # T x B x C => T x B x T' x C x = x.view(t, -1, x.shape[-1]) # T x (B * T') x C # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if self.layer_norm: x = self.layer_norm(x) # for each sample, expand to a long vector x = x.view(t, -1, t_fw, x.shape[-1]).permute(2, 0, 1, 3) # T x B x T' x C => T' x T x B x C x = x[~time_mask, :, :] # Tfinal x B x C forward_idx = torch.arange(t_fw).to(x.device).unsqueeze(1).repeat(1, t)[~time_mask] # Tfinal, indicating which token belongs to which forward encoder_padding_mask = encoder_padding_mask.view(-1, t_fw, t)[:, ~time_mask] return { 'encoder_out': x, 'encoder_padding_mask': encoder_padding_mask, 'forward_idx': forward_idx, 'src_lengths': src_lengths, } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ # pdb.set_trace() if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) encoder_out['src_lengths'] = encoder_out['src_lengths'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i)) version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
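# Editorial sketch: the multi-pass simultaneous-MT encoder above replicates the
# source T' = T - k + 1 times and, for the i-th pass, masks every source
# position the wait-k reader has not yet seen. A stand-alone reconstruction of
# that time mask, assuming only torch; t and wait_k are toy values.
import torch

def _sketch_wait_k_time_mask(t=5, wait_k=3):
    time_mask = torch.arange(t)[None, :] > torch.arange(t)[:, None]   # T x T upper triangle
    time_mask = time_mask[min(wait_k, t) - 1:, :]                      # T' x T, one row per pass
    return time_mask   # True marks source positions hidden from that forward pass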
class TransformerDecoder(FairseqIncrementalDecoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim self.output_embed_dim = args.decoder_output_dim padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt( embed_dim) # todo: try with input_embed_dim # calculate copy probability p(z=1) batch self.copy = args.copy self.project_in_dim = Linear( input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None self.embed_positions = PositionalEmbedding( args.max_target_positions, embed_dim, padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerDecoderLayer(args, no_encoder_attn) for _ in range(args.decoder_layers) ]) if self.copy: self.copy_attn = MultiheadAttention( embed_dim, 1, dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.linear_copy = Linear(embed_dim, 1) self.adaptive_softmax = None self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \ if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, options.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif not self.share_input_output_embed: self.embed_out = nn.Parameter( torch.Tensor(len(dictionary), self.output_embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5) if args.decoder_normalize_before and not getattr( args, 'no_decoder_final_norm', False): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (Tensor, optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state) x = self.output_layer(x) return x, extra def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): """ Similar to *forward* but only return features. 
Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ # embed positions positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if self.embed_positions is not None else None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) attn = None inner_states = [x] # decoder layers for layer in self.layers: x, attn = layer( x, encoder_out['encoder_out'] if encoder_out is not None else None, encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state, self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, ) inner_states.append(x) if self.layer_norm: x = self.layer_norm(x) copy_x, copy_attn = None, None if self.copy: copy_x, copy_attn = self.copy_attn( query=x, key=encoder_out['encoder_out'] if encoder_out is not None else None, value=encoder_out['encoder_out'] if encoder_out is not None else None, key_padding_mask=encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state=incremental_state, static_kv=True, need_weights=True, ) # copy_x = copy_x.transpose(0, 1) p_copy = None if self.copy: # p_copy = torch.sigmoid(self.linear_copy(copy_attn)) p_copy = torch.sigmoid(self.linear_copy(x)).transpose(0, 1) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) # return x, {'attn': attn, 'inner_states': inner_states, 'p_copy': p_copy} return x, { 'attn': attn, 'inner_states': inner_states, 'p_copy': p_copy, 'copy_attn': copy_attn } def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary if self.share_input_output_embed: return F.linear(features, self.embed_tokens.weight) else: return F.linear(features, self.embed_out) else: return features def get_normalized_probs(self, net_output, log_probs, sample): """Get normalized probabilities (or log probs) from a net's output.""" if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None: if sample is not None: assert 'target' in sample target = sample['target'] else: target = None out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) return out.exp_() if not log_probs else out logits = net_output[0] is_copy = 'p_copy' in net_output[1].keys( ) and net_output[1]['p_copy'] is not None # print(net_output[1]['attn']) if is_copy and False: p_copy = net_output[1]['p_copy'] if 'net_input' in sample.keys(): enc_seq_ids = sample['net_input']['src_tokens'] else: # for decode step enc_seq_ids = sample['src_tokens'] enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat( 1, net_output[1]['copy_attn'].size(1), 1) generate_prob = utils.softmax( logits, dim=-1, onnx_trace=self.onnx_trace) * (1 - p_copy) copy_prob = net_output[1]['copy_attn'] * p_copy final = generate_prob.scatter_add(2, enc_seq_ids, copy_prob) if log_probs: return torch.log(final + 1e-15) else: return final else: if log_probs: return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) else: return utils.softmax(logits, dim=-1, 
onnx_trace=self.onnx_trace) def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): dim = tensor.size(0) if not hasattr( self, '_future_mask' ) or self._future_mask is None or self._future_mask.device != tensor.device or self._future_mask.size( 0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms layer_norm_map = { '0': 'self_attn_layer_norm', '1': 'encoder_attn_layer_norm', '2': 'final_layer_norm' } for old, new in layer_norm_map.items(): for m in ('weight', 'bias'): k = '{}.layers.{}.layer_norms.{}.{}'.format( name, i, old, m) if k in state_dict: state_dict['{}.layers.{}.{}.{}'.format( name, i, new, m)] = state_dict[k] del state_dict[k] version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
class JointAttentionDecoder(FairseqIncrementalDecoder): """ JointAttention decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`ProtectedTransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding left_pad (bool, optional): whether the input is left-padded. Default: ``False`` """ def __init__( self, args, dictionary, embed_tokens, left_pad=False, final_norm=True): super().__init__(dictionary) self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed self.kernel_size_list = args.kernel_size_list input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim output_embed_dim = args.decoder_output_dim padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.project_in_dim = Linear( input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None self.embed_positions = PositionalEmbedding( args.max_target_positions, embed_dim, padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.embed_language = LanguageEmbedding( embed_dim) if args.language_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ ProtectedTransformerDecoderLayer(args, no_encoder_attn=True) for _ in range(args.decoder_layers) ]) self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \ if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None if not self.share_input_output_embed: self.embed_out = nn.Parameter( torch.Tensor( len(dictionary), output_embed_dim)) nn.init.normal_( self.embed_out, mean=0, std=output_embed_dim ** -0.5) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.decoder_normalize_before and final_norm if self.normalize: self.layer_norm = LayerNorm(embed_dim) # self.skipped_layer = 0 max_level = 10 step = 1. / float(max_level) levels = [round(x * step, 2) for x in range(1, max_level + 1)] # automate self.skip_layers = self.jump_or_not(levels) # manual set # self.skip_layers = self.jump_or_not_manual() self.layer_drop_rate = 0. 
# self.scaling = False self.Formula = False self.source_target_drop = False print("SKip Layers :", self.skip_layers) print("layer_drop_rate:%f" % self.layer_drop_rate) def jump_or_not_manual(self): # manual all = {0.33: [0, 1, 4, 8, 12], 0.66: [0, 1, 2, 4, 6, 8, 10, 12], 1.0: [x for x in range(0, 14)]} # all = {0.33: [0, 1, 4, 8, 12], 0.66: [0, 1, 3, 5, 6, 8, 10, 12], 1.0: [x for x in range(0, 14)]} # all = {0.33: [0, 1, 3, 8], 0.66: [0, 1, 2, 4, 5, 8, 10], 1.0: [x for x in range(0, 14)]} return all def jump_or_not(self, levels): all = {} for level in levels: last_target_layer = int(len(self.layers) * level) step = round(len(self.layers) / last_target_layer) skip_layers = [] for i in range(0, len(self.layers), step): skip_layers.append(i) # append last layer if skip_layers[-1] != len(self.layers) - 1: skip_layers.append(len(self.layers) - 1) all[level] = skip_layers return all def base_jump_or_not(self): # training time # if i == last_target_layer: # break # inference only time, 3 losses # if not self.training and i == last_target_layer: # return True # layerdrop-method # r = random.random() return True def layer_drop(self, i): p = self.layer_drop_rate # i += 1 n = random.random() if self.Formula: pl = float((i / len(self.layers)) * (1. - p)) else: pl = p return n <= pl def scale_whole_layer(self, i): i += 1 p = self.layer_drop_rate # 2016 paper, scale down pl = 1 - float((i / len(self.layers)) * (1. - p)) # 2019 paper Speech # pl = (i / len(self.layers)) * (1 - p) # pl = 1 / (1 - pl) return pl def forward( self, prev_output_tokens, encoder_out, incremental_state=None, level=1.): """ Args: input (dict): with prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing encoder_out (Tensor, optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` Returns: tuple: - the last decoder layer's output of shape `(batch, tgt_len, vocab)` - the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)` """ tgt_len = prev_output_tokens.size(1) # embed positions positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if self.embed_positions is not None else None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions # language embedding if self.embed_language is not None: lang_emb = self.embed_scale * self.embed_language.view(1, 1, -1) x += lang_emb x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) attn = None inner_states = [x] source = encoder_out['encoder_out'] process_source = incremental_state is None or len( incremental_state) == 0 # extended padding mask source_padding_mask = encoder_out['encoder_padding_mask'] if source_padding_mask is not None: target_padding_mask = source_padding_mask.new_zeros( (source_padding_mask.size(0), tgt_len)) self_attn_padding_mask = torch.cat( (source_padding_mask, target_padding_mask), dim=1) else: self_attn_padding_mask = None # inference time if 'level' in encoder_out: level = encoder_out['level'] # fix all the batches's level to be easy # level = 0.33 # transformer layers for i, layer in enumerate(self.layers): # training 
with dropout - normal way if self.training and self.layer_drop(i): continue # skipping inference # if not self.training and i not in self.skip_layers[level]: # continue # if self.kernel_size_list is not None: target_mask = self.local_mask( x, self.kernel_size_list[i], causal=True, tgt_len=tgt_len) elif incremental_state is None: target_mask = self.buffered_future_mask(x) else: target_mask = None if target_mask is not None: zero_mask = target_mask.new_zeros( (target_mask.size(0), source.size(0))) self_attn_mask = torch.cat((zero_mask, target_mask), dim=1) else: self_attn_mask = None # if self.source_target_drop and not self.layer_drop(i) or i == 0 or not self.training: state = incremental_state if process_source: if state is None: state = {} if self.kernel_size_list is not None: source_mask = self.local_mask( source, self.kernel_size_list[i], causal=False) else: source_mask = None source, attn = layer( source, None, None, state, self_attn_mask=source_mask, self_attn_padding_mask=source_padding_mask ) inner_states.append(source) # if self.source_target_drop and not self.layer_drop(i) or not self.training: x, attn = layer( x, None, None, state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask ) # x scaling if self.scaling: # training x = x * self.scale_whole_layer(i) inner_states.append(x) if self.normalize: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) # project back to size of vocabulary if self.share_input_output_embed: x = F.linear(x, self.embed_tokens.weight) else: x = F.linear(x, self.embed_out) pred = x info = {'attn': attn, 'inner_states': inner_states} return pred, info # # 3 loss ways in layer 4,9,14 -> 0 + 8 + 10 + 10 == 8,18,28 # if self.training: # step = [8, 18, 28] # # step = [28] # else: # step = [len(inner_states) - 1] # loss_output = [] # for i in step: # x = inner_states[i] # # # if self.normalize: # x = self.layer_norm(x) # # # T x B x C -> B x T x C # x = x.transpose(0, 1) # # if self.project_out_dim is not None: # x = self.project_out_dim(x) # # # project back to size of vocabulary # if self.share_input_output_embed: # x = F.linear(x, self.embed_tokens.weight) # else: # x = F.linear(x, self.embed_out) # # pred = x # info = {'attn': attn, 'inner_states': inner_states} # loss_output.append((pred, info)) # # # return pred, info # if not self.training and len(loss_output) == 1: # return loss_output[0] # return loss_output def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min( self.max_target_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): """Cached future mask.""" dim = tensor.size(0) # pylint: disable=access-member-before-definition, # attribute-defined-outside-init if not hasattr( self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu( utils.fill_with_neg_inf( tensor.new( dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf( self._future_mask.resize_( dim, dim)), 1) return self._future_mask[:dim, :dim] def local_mask(self, tensor, kernel_size, causal, tgt_len=None): """Locality constraint mask.""" rows = tensor.size(0) cols = tensor.size(0) if tgt_len is None else tgt_len if causal: if rows == 1: mask = utils.fill_with_neg_inf(tensor.new(1, cols)) mask[0, -kernel_size:] = 0 return mask else: 
diag_u, diag_l = 1, kernel_size else: diag_u, diag_l = ((kernel_size + 1) // 2, (kernel_size + 1) // 2) if kernel_size % 2 == 1 else (kernel_size // 2, kernel_size // 2 + 1) mask1 = torch.triu( utils.fill_with_neg_inf( tensor.new( rows, cols)), diag_u) mask2 = torch.tril( utils.fill_with_neg_inf( tensor.new( rows, cols)), -diag_l) return mask1 + mask2
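# Illustrative sketch (standalone re-statement with toy sizes): the banded
# locality mask produced by JointAttentionDecoder.local_mask() above, i.e. an
# additive mask that is 0 inside the local attention window and -inf outside it.
import torch


def band_mask(rows, cols, kernel_size, causal):
    neg_inf = torch.full((rows, cols), float('-inf'))
    if causal:
        # attend to self plus the previous kernel_size - 1 positions
        diag_u, diag_l = 1, kernel_size
    elif kernel_size % 2 == 1:
        # symmetric window around the diagonal
        diag_u = diag_l = (kernel_size + 1) // 2
    else:
        diag_u, diag_l = kernel_size // 2, kernel_size // 2 + 1
    return torch.triu(neg_inf, diag_u) + torch.tril(neg_inf, -diag_l)


# band_mask(5, 5, kernel_size=2, causal=True)[3] == [-inf, -inf, 0., 0., -inf]:
# position 3 may attend only to positions 2 and 3.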
class TransformerDecoder(FairseqIncrementalDecoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). final_norm (bool, optional): apply layer norm to the output of the final decoder layer (default: True). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True): super().__init__(dictionary) self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim self.output_embed_dim = args.decoder_output_dim padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt( embed_dim) # todo: try with input_embed_dim self.project_in_dim = Linear( input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None self.embed_positions = PositionalEmbedding( args.max_target_positions, embed_dim, padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerDecoderLayer(args, no_encoder_attn) for _ in range(args.decoder_layers) ]) self.adaptive_softmax = None self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \ if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, options.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif not self.share_input_output_embed: self.embed_out = nn.Parameter( torch.Tensor(len(dictionary), self.output_embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.decoder_normalize_before and final_norm if self.normalize: self.layer_norm = LayerNorm(embed_dim) self.onnx_trace = False self.decoder_max_order = args.decoder_max_order self.clamp_value = getattr(args, 'clamp_value', 0.01) self.gs_clamp = args.gs_clamp def set_perm_order(self, perm_order=0): assert isinstance(perm_order, int) and 0 <= perm_order <= 5 for layer in self.layers: layer.set_perm_order(perm_order) def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing encoder_out (Tensor, optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state) x = self.output_layer(x, encoder_out) return x, extra def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): """ Similar to 
*forward* but only return features. Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ # embed positions positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if self.embed_positions is not None else None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) attn = None inner_states = [x] # decoder layers for layer in self.layers: x, attn = layer( x, encoder_out['encoder_out'] if encoder_out is not None else None, encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state, self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, ) inner_states.append(x) if self.normalize: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) return x, {'attn': attn, 'inner_states': inner_states} def output_layer(self, features, encoder_out, **kwargs): """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary if self.share_input_output_embed: return [ F.linear(features, self.embed_tokens.weight), encoder_out['encoder_pred_order'] ] else: return F.linear(features, self.embed_out) else: return features def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): dim = tensor.size(0) if not hasattr( self, '_future_mask' ) or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms layer_norm_map = { '0': 'self_attn_layer_norm', '1': 'encoder_attn_layer_norm', '2': 'final_layer_norm' } for old, new in layer_norm_map.items(): for m in ('weight', 'bias'): k = '{}.layers.{}.layer_norms.{}.{}'.format( name, i, old, m) if k in state_dict: state_dict['{}.layers.{}.{}.{}'.format( name, i, new, m)] = state_dict[k] del state_dict[k] if utils.item( state_dict.get('{}.version'.format(name), torch.Tensor( [1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict['{}.version'.format(name)] = torch.Tensor([1]) return state_dict def get_normalized_probs(self, net_output, log_probs, sample, gs_tau=0.5, gs_hard=False): """Get normalized probabilities (or log probs) 
from a net's output.""" if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None: if sample is not None: assert 'target' in sample target = sample['target'] else: target = None out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) return out.exp_() if not log_probs else out logits = net_output[0][0] orders = net_output[0][1] if log_probs: return (utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace), *self.gumbel_softmax( orders, gs_tau=gs_tau, gs_hard=gs_hard, dim=-1)) else: return (utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace), *self.gumbel_softmax( orders, gs_tau=gs_tau, gs_hard=gs_hard, dim=-1)) def gumbel_softmax(self, logits, gs_tau=0.5, gs_hard=False, dim=-1): if not gs_hard: prob = utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace) prob_clamp = torch.clamp( prob, self.clamp_value, 1. - (self.decoder_max_order - 1) * self.clamp_value) logprob = torch.log(prob_clamp if self.gs_clamp else prob) gs = F.gumbel_softmax( logprob, tau=gs_tau, hard=False, ) else: prob = utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace) prob_clamp = torch.clamp( prob, self.clamp_value, 1. - (self.decoder_max_order - 1) * self.clamp_value) max_idx = torch.argmax(logits, -1, keepdim=True) one_hot = logits.new_zeros(logits.size()) gs = one_hot.scatter(-1, max_idx, 1) return gs, prob, prob_clamp
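# Illustrative sketch (toy standalone function; `num_orders` plays the role of
# self.decoder_max_order and `clamp_value` of self.clamp_value above): the
# clamped Gumbel-Softmax used for the permutation-order distribution. Clamping
# keeps every order's probability strictly positive, so the log stays finite
# before the Gumbel noise is added.
import torch
import torch.nn.functional as F


def clamped_gumbel_sample(order_logits, num_orders, clamp_value=0.01, tau=0.5):
    prob = F.softmax(order_logits, dim=-1)
    prob_clamp = torch.clamp(prob, clamp_value,
                             1. - (num_orders - 1) * clamp_value)
    # soft, differentiable sample; rows of the result sum to 1
    return F.gumbel_softmax(torch.log(prob_clamp), tau=tau, hard=False)


# sample = clamped_gumbel_sample(torch.randn(2, 6), num_orders=6)  # shape (2, 6)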
class HierarchicalTransformerEncoder(FairseqEncoder): """ hierarchical_transformer_sent encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`hierarchical_transformer_sentEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim / 2, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.embed_positions2 = PositionalEmbedding( args.max_source_positions, embed_dim / 4, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend( [TransformerLayer(args) for _ in range(args.encoder_layers)]) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.encoder_normalize_before if self.normalize: self.layer_norm = LayerNorm(embed_dim) self.sentence_norm = LayerNorm(embed_dim) self.doc_norm = LayerNorm(embed_dim) def forward(self, src_tokens, src_lengths, block_mask, doc_lengths, doc_block_mask): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, n_blocks, n_tokens)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` block_mask (torch.LongTensor): block mask of the source sentences of shape `(batch, n_blocks, n_blocks)` doc_lengths (torch.LongTensor): doc mask of the source sentences of shape `(batch)` doc_block_mask (torch.LongTensor): doc mask of the source sentences of shape `(batch, n_docs, n_blocks)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions batch_size, n_blocks, n_tokens = src_tokens.size() doc_padding_mask = torch.arange(0, doc_lengths.max()) doc_padding_mask = doc_padding_mask.repeat(doc_lengths.numel(), 1) doc_padding_mask = 1 - doc_padding_mask.lt( doc_lengths.unsqueeze(1).cpu()) doc_padding_mask = doc_padding_mask.byte().cuda() n_docs = doc_padding_mask.size(1) x = self.embed_scale * self.embed_tokens(src_tokens) # if self.embed_positions is not None: local_pos_emb = self.embed_positions( src_tokens.view(batch_size * n_blocks, n_tokens)) local_pos_emb = local_pos_emb.view(batch_size, n_blocks, n_tokens, -1) def collate_embedding(values, pad_idx, size): """Convert a list of 2d tensors into a padded 3d tensor.""" # size = max(v.size(0) for v in values) res = values[0][0, 0].new(len(values), size, values[0].size(1)).fill_(pad_idx) def copy_tensor(src, dst): assert dst.numel() == src.numel() dst.copy_(src) for i, v in enumerate(values): copy_tensor(v[:min(v.size(0), size)], res[i][:v.size(0)]) return res doc_sentence_lengths = torch.sum(doc_block_mask, 2) # (batch, n_docs) block_pos_emb = self.embed_positions2(torch.sum( src_tokens, 2)) # (batch, n_blocks, embed_dim) block_pos_emb = collate_embedding([ torch.cat([ block_pos_emb[i, :doc_sentence_lengths[i, j]] for j in range(n_docs) if doc_sentence_lengths[i, j] 
!= 0 ], 0) for i in range(block_pos_emb.size(0)) ], 0, n_blocks) # (batch, n_blocks, embed_dim) block_pos_emb = block_pos_emb.unsqueeze(2).repeat(1, 1, n_tokens, 1) def collate_embedding(values, pad_idx, size): """Convert a list of 2d tensors into a padded 3d tensor.""" # size = max(v.size(0) for v in values) res = values[0][0, 0].new(len(values), size, values[0].size(1)).fill_(pad_idx) def copy_tensor(src, dst): assert dst.numel() == src.numel() dst.copy_(src) for i, v in enumerate(values): copy_tensor(v[:min(v.size(0), size)], res[i][:v.size(0)]) return res doc_pos_emb = self.embed_positions2( doc_sentence_lengths) # (batch, n_docs, embed_dim) doc_pos_emb = collate_embedding([ torch.cat([ doc_pos_emb[i, j].unsqueeze(0).repeat( doc_sentence_lengths[i, j], 1) for j in range(doc_pos_emb.size(1)) if doc_sentence_lengths[i, j] != 0 ], 0) for i in range(doc_pos_emb.size(0)) ], 0, n_blocks) # (batch, n_blocks, embed_dim) doc_pos_emb = doc_pos_emb.unsqueeze(2).repeat(1, 1, n_tokens, 1) combined_pos_emb = torch.cat( [local_pos_emb, block_pos_emb, doc_pos_emb], -1) x += combined_pos_emb x = F.dropout(x, p=self.dropout, training=self.training) # compute padding mask local_padding_mask = src_tokens.eq(self.padding_idx).view( batch_size * n_blocks, n_tokens) block_padding_mask = torch.sum( 1 - local_padding_mask.view(batch_size, n_blocks, n_tokens), -1) == 0 x = x.view(batch_size * n_blocks, n_tokens, -1) # B x T x C -> T x B x C x = x.transpose(0, 1) block_vec = torch.zeros(n_blocks, batch_size, self.embed_tokens.embedding_dim).cuda() doc_vec = torch.zeros(n_docs, batch_size, self.embed_tokens.embedding_dim).cuda() # encoder local layers for layer in self.layers: x, block_vec, doc_vec = layer(x, block_vec, doc_vec, local_padding_mask, block_padding_mask, doc_padding_mask, block_mask, doc_block_mask, batch_size, n_blocks) if self.normalize: x = self.layer_norm(x) block_vec = self.sentence_norm(block_vec) doc_vec = self.doc_norm(doc_vec) # T x B x C -> B x T x C x = x.transpose(0, 1) mask_hier = 1 - local_padding_mask[:, :, None].float() src_features = x * mask_hier src_features = src_features.view(batch_size, n_blocks * n_tokens, -1) src_features = src_features.transpose( 0, 1).contiguous() # src_len, batch_size, hidden_dim mask_hier = mask_hier.view(batch_size, n_blocks * n_tokens, -1) mask_hier = mask_hier.transpose(0, 1).contiguous() unpadded = [ torch.masked_select(src_features[:, i], mask_hier[:, i].byte()).view( [-1, src_features.size(-1)]) for i in range(src_features.size(1)) ] max_l = max([p.size(0) for p in unpadded]) def sequence_mask(lengths, max_len=None): """ Creates a boolean mask from sequence lengths. """ batch_size = lengths.numel() max_len = max_len or lengths.max() return (torch.arange(0, max_len).type_as(lengths).repeat( batch_size, 1).lt(lengths.unsqueeze(1))) mask_hier = sequence_mask(torch.tensor([p.size(0) for p in unpadded]), max_l).cuda() mask_hier = 1 - mask_hier[:, None, :] unpadded = torch.stack([ torch.cat([ p, torch.zeros(max_l - p.size(0), src_features.size(-1)).cuda() ]) for p in unpadded ], 1) x = unpadded # x = unpadded.transpose(0, 1) encoder_padding_mask = mask_hier.squeeze(1) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T 'sentence_out': block_vec, # T x B x C 'sentence_padding_mask': block_padding_mask, # B x T 'doc_out': doc_vec, # T x B x C 'doc_padding_mask': doc_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) if encoder_out['sentence_out'] is not None: encoder_out['sentence_out'] = \ encoder_out['sentence_out'].index_select(1, new_order) if encoder_out['sentence_padding_mask'] is not None: encoder_out['sentence_padding_mask'] = \ encoder_out['sentence_padding_mask'].index_select(0, new_order) if encoder_out['doc_out'] is not None: encoder_out['doc_out'] = \ encoder_out['doc_out'].index_select(1, new_order) if encoder_out['doc_padding_mask'] is not None: encoder_out['doc_padding_mask'] = \ encoder_out['doc_padding_mask'].index_select(0, new_order) return encoder_out def reorder_encoder_input(self, encoder_input, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ # print('reorder') if encoder_input['src_tokens'] is not None: encoder_input['src_tokens'] = \ encoder_input['src_tokens'].index_select(0, new_order) if encoder_input['src_lengths'] is not None: encoder_input['src_lengths'] = \ encoder_input['src_lengths'].index_select(0, new_order) if encoder_input['block_mask'] is not None: encoder_input['block_mask'] = \ encoder_input['block_mask'].index_select(0, new_order) if encoder_input['doc_block_mask'] is not None: encoder_input['doc_block_mask'] = \ encoder_input['doc_block_mask'].index_select(0, new_order) if encoder_input['doc_lengths'] is not None: encoder_input['doc_lengths'] = \ encoder_input['doc_lengths'].index_select(0, new_order) return encoder_input def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named(state_dict, f"{name}.layers.{i}") version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
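# Illustrative sketch (boolean reformulation, not used by the class): the
# lengths-to-padding-mask construction that HierarchicalTransformerEncoder
# above performs for doc_padding_mask and in its sequence_mask helper.
import torch


def padding_mask_from_lengths(lengths, max_len=None):
    """lengths: (B,) LongTensor -> (B, max_len) bool mask, True at pad slots."""
    max_len = max_len if max_len is not None else int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return positions[None, :] >= lengths[:, None]


# padding_mask_from_lengths(torch.tensor([3, 1])) ==
#   [[False, False, False],
#    [False,  True,  True]]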
class HybridEncoder(FairseqEncoder): def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.dropout = args.embed_dropout self.embed_tokens = embed_tokens self.embed_dim = embed_tokens.embedding_dim self.embed_scale = math.sqrt(self.embed_dim) self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_positions = PositionalEmbedding( self.max_source_positions, self.embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.encoder_layers = args.encoder_layers self.convlayers = nn.ModuleList([ DynamicConvEncoderLayer(args.encoder_embed_dim, args.encoder_embed_dim, args.encoder_attention_heads, args.encoder_kernel_size_list[i], input_dropout=args.conv_input_dropout, weight_dropout=args.conv_weight_dropout, dropout=args.conv_output_dropout) for i in range(self.encoder_layers) ]) self.attnlayers = nn.ModuleList([ AttentionEncoderLayer(args.encoder_embed_dim, args.encoder_attention_heads, self_attention=True, attention_dropout=args.attn_weight_dropout, dropout=args.attn_output_dropout) for _ in range(self.encoder_layers) ]) self.fflayers = nn.ModuleList([ FFLayer(args.encoder_embed_dim, args.encoder_ffn_embed_dim, relu_dropout=args.ff_relu_dropout, dropout=args.ff_output_dropout) for _ in range(self.encoder_layers) ]) self.ratios = nn.Parameter(torch.FloatTensor(self.encoder_layers, 1), requires_grad=True) self.ratios.data.fill_(0.5) # self.ratios = [nn.Parameter(torch.FloatTensor(1), requires_grad=True).cuda() for _ in range(7)] # for ratio in self.ratios: # ratio.data.fill_(0.5) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.encoder_normalize_before if self.normalize: self.layer_norm = LayerNorm(self.embed_dim) def forward(self, src_tokens, **unused): x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) x = x.transpose(0, 1) encoder_padding_mask = src_tokens.eq(self.padding_idx) ### I want to keep the mask anyway # if not encoder_padding_mask.any(): # encoder_padding_mask = None encoder_states = [] for i in range(self.encoder_layers): x1, state1 = self.convlayers[i]( x, encoder_padding_mask=encoder_padding_mask) x2, state2 = self.attnlayers[i]( x, encoder_padding_mask=encoder_padding_mask) if state1 is not None: encoder_states.append(state1) if state2 is not None: encoder_states.append(state2) x = x1 * self.ratios[i] + x2 * (1 - self.ratios[i]) # x = 0.5*x1 + 0.5*x2 x, _ = self.fflayers[i](x, encoder_padding_mask=encoder_padding_mask) if self.normalize: x = self.layer_norm(x) return { 'encoder_x': x, 'encoder_padding_mask': encoder_padding_mask, 'encoder_lstm_states': self.get_lstm_states(encoder_states) } def construct_encoder_layer(self, gene): if gene['type'] == 'recurrent': return LSTMEncoderLayer(**gene['param']) elif gene['type'] == 'lightconv': return LightConvEncoderLayer(**gene['param']) elif gene['type'] == 'dynamicconv': return DynamicConvEncoderLayer(**gene['param']) elif gene['type'] == 'self-attention': return AttentionEncoderLayer(**gene['param'], self_attention=True) elif gene['type'] == 'ff': return FFLayer(**gene['param']) else: raise NotImplementedError('Unknown Decoder Gene Type!') def get_lstm_states(self, encoder_states): # only return the state of the topmost lstm layer final_state = None for state in encoder_states: if state is not None and "lstm_hidden_state" in 
state.keys(): final_state = state["lstm_hidden_state"] return final_state def reorder_encoder_out(self, encoder_out, new_order): if encoder_out['encoder_x'] is not None: encoder_out['encoder_x'] = \ encoder_out['encoder_x'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) if encoder_out['encoder_lstm_states'] is not None: hiddens, cells = encoder_out['encoder_lstm_states'] hiddens = hiddens.index_select(1, new_order) cells = cells.index_select(1, new_order) encoder_out['encoder_lstm_states'] = (hiddens, cells) return encoder_out def max_positions(self): if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions())
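# Illustrative sketch (hypothetical toy module, not wired into HybridEncoder):
# the per-layer mixing used above, where a learned scalar ratio (initialised to
# 0.5 and not explicitly constrained to [0, 1]) interpolates between the
# convolution branch and the self-attention branch before the feed-forward block.
import torch
import torch.nn as nn


class BranchMix(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        self.ratios = nn.Parameter(torch.full((num_layers, 1), 0.5))

    def forward(self, layer_idx, conv_out, attn_out):
        r = self.ratios[layer_idx]            # scalar ratio for this layer
        return r * conv_out + (1. - r) * attn_out


# mix = BranchMix(num_layers=6)
# y = mix(0, torch.randn(10, 4, 512), torch.randn(10, 4, 512))  # T x B x C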
class TaLKConvDecoder(FairseqIncrementalDecoder): """ Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs. Default: ``False`` """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True): super().__init__(dictionary) self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim output_embed_dim = args.decoder_output_dim padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt( embed_dim) # todo: try with input_embed_dim self.project_in_dim = Linear( input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None self.embed_positions = PositionalEmbedding( args.max_target_positions, embed_dim, padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TaLKConvDecoderLayer(args, no_encoder_attn, kernel_size=args.decoder_kernel_size_list[i]) for i in range(args.decoder_layers) ]) self.adaptive_softmax = None self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \ if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), output_embed_dim, options.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif not self.share_input_output_embed: self.embed_out = nn.Parameter( torch.Tensor(len(dictionary), output_embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim**-0.5) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.decoder_normalize_before and final_norm if self.normalize: self.layer_norm = LayerNorm(embed_dim) self.acts_reg = [] def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (Tensor, optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` Returns: tuple: - the last decoder layer's output of shape `(batch, tgt_len, vocab)` - the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)` """ # embed positions positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if self.embed_positions is not None else None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) attn = None inner_states = [x] # decoder layers for layer in self.layers: x, attn = layer( x, encoder_out['encoder_out'] if 
encoder_out is not None else None, encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state) inner_states.append(x) if self.normalize: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) if self.adaptive_softmax is None: # project back to size of vocabulary if self.share_input_output_embed: x = F.linear(x, self.embed_tokens.weight) else: x = F.linear(x, self.embed_out) return x, {'attn': attn, 'inner_states': inner_states} def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): dim = tensor.size(0) if not hasattr( self, '_future_mask' ) or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) return self._future_mask[:dim, :dim]
class LightConvEncoder(FairseqEncoder): """ LightConv encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`LightConvEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ LightConvEncoderLayer(args, kernel_size=args.encoder_kernel_size_list[i]) for i in range(args.encoder_layers) ]) self.encoder_dynamic_combination = args.encoder_dynamic_combination self.encoder_linear_combination = args.encoder_linear_combination assert not (self.encoder_dynamic_combination and self.encoder_linear_combination) if self.encoder_linear_combination or self.encoder_dynamic_combination: self.weight_ffn = nn.Sequential( nn.Linear(embed_dim, args.encoder_ffn_embed_dim), nn.ReLU(), nn.Linear(args.encoder_ffn_embed_dim, embed_dim), ) if args.encoder_dynamic_combination: self.proj = nn.ModuleList([ nn.Sequential( nn.Linear(embed_dim * args.encoder_layers, embed_dim * 2), nn.ReLU(), nn.Linear(embed_dim * 2, embed_dim), ) for _ in range(args.encoder_layers) ]) if args.encoder_linear_combination: self.weights = nn.ParameterList([ nn.Parameter(torch.randn(1, 1, embed_dim), requires_grad=True) for _ in range(args.encoder_layers) ]) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.encoder_normalize_before if self.normalize: self.layer_norm = LayerNorm(embed_dim) def forward(self, src_tokens, **unused): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if not encoder_padding_mask.any(): encoder_padding_mask = None if self.encoder_dynamic_combination or self.encoder_linear_combination: hiddens = [] else: hiddens = None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if self.encoder_dynamic_combination or self.encoder_linear_combination: hiddens.append(x) if self.encoder_dynamic_combination: assert torch.equal(x, hiddens[-1]) acc_x = torch.zeros_like(x) catted_hidden = torch.cat(hiddens.unbind(), -1) for i, layer in enumerate(self.proj): acc_x += layer(catted_hidden) * hiddens[i] x = acc_x + self.weight_ffn(acc_x) if self.encoder_linear_combination: assert torch.equal(x, hiddens[-1]) acc_x = torch.zeros_like(x) for i, weight in enumerate(self.weights): acc_x += weight * hiddens[i] x = acc_x + self.weight_ffn(acc_x) if self.normalize: x = self.layer_norm(x) return { 
'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions())
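# Illustrative sketch (standalone functions; `hiddens` stands for the list of
# per-layer outputs collected in LightConvEncoder.forward() above): the two
# layer-combination options. The linear variant weights each layer's output
# with a learned per-channel vector; the dynamic variant derives the weights
# from the channel-wise concatenation of all layer outputs.
import torch


def linear_combination(hiddens, weights):
    """hiddens: list of (T, B, C); weights: list of (1, 1, C) parameters."""
    acc = torch.zeros_like(hiddens[-1])
    for h, w in zip(hiddens, weights):
        acc = acc + w * h
    return acc


def dynamic_combination(hiddens, proj_layers):
    """proj_layers[i]: module mapping (T, B, C * num_layers) -> (T, B, C)."""
    catted = torch.cat(hiddens, dim=-1)      # concatenate along channels
    acc = torch.zeros_like(hiddens[-1])
    for h, proj in zip(hiddens, proj_layers):
        acc = acc + proj(catted) * h
    return acc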
class transformer_with_copyDecoder(FairseqIncrementalDecoder): """ transformer_with_copy decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`transformer_with_copyDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). final_norm (bool, optional): apply layer norm to the output of the final decoder layer (default: True). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True): super().__init__(dictionary) self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim output_embed_dim = args.decoder_output_dim padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt( embed_dim) # todo: try with input_embed_dim self.project_in_dim = Linear( input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None self.embed_positions = PositionalEmbedding( args.max_target_positions, embed_dim, padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ transformer_with_copyDecoderLayer(args, no_encoder_attn) for _ in range(args.decoder_layers) ]) self.copy_attention = MultiheadOnlyAttention( embed_dim, 1, dropout=0, ) self.copy_or_generate = nn.Sequential(nn.Linear(embed_dim, 1), nn.Sigmoid()) self.adaptive_softmax = None self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \ if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), output_embed_dim, options.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif not self.share_input_output_embed: self.embed_out = nn.Parameter( torch.Tensor(len(dictionary), output_embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim**-0.5) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.decoder_normalize_before and final_norm if self.normalize: self.layer_norm = LayerNorm(embed_dim) def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing encoder_out (Tensor, optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` Returns: tuple: - the last decoder layer's output of shape `(batch, tgt_len, vocab)` - the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)` """ # embed positions positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if self.embed_positions is not None else None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * 
self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) inner_states = [x] # decoder layers for layer in self.layers: x, _ = layer( x, encoder_out['encoder_out'] if encoder_out is not None else None, encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state, self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, ) inner_states.append(x) if self.normalize: x = self.layer_norm(x) _, copy = self.copy_attention( query=x, key=encoder_out['encoder_out'] if encoder_out is not None else None, value=encoder_out['encoder_out'] if encoder_out is not None else None, key_padding_mask=encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state=incremental_state, static_kv=True, need_weights=True, ) copy_or_generate = self.copy_or_generate(x).transpose(0, 1) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) if self.adaptive_softmax is None: # project back to size of vocabulary if self.share_input_output_embed: x = F.linear(x, self.embed_tokens.weight) else: x = F.linear(x, self.embed_out) return x, { 'attn': copy, 'inner_states': inner_states, 'copy_or_generate': copy_or_generate } def get_normalized_probs(self, net_output, log_probs, sample): """Get normalized probabilities (or log probs) from a net's output.""" # print('enter normalized.') if 'net_input' in sample.keys(): enc_seq_ids = sample['net_input']['src_tokens'] else: enc_seq_ids = sample['src_tokens'] # wvocab_size = net_output[0].size(2) # batch_size = enc_seq_ids.size(0) # seq_len = enc_seq_ids.size(1) # one_hot = torch.zeros(batch_size, seq_len, wvocab_size).cuda().scatter_(dim=2, index=enc_seq_ids.unsqueeze(-1), value=1) # # copy_probs = torch.matmul(net_output[1]['attn'], one_hot) # final_dist = vocab_dist.scatter_add(1, encoder_batch_extend_vocab, attn_dist) if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None: if sample is not None: assert 'target' in sample target = sample['target'] else: target = None out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) return out.exp_() if not log_probs else out logits = net_output[0] if log_probs: generate = utils.softmax( logits, dim=-1, onnx_trace=self.onnx_trace) * net_output[1]['copy_or_generate'] copy = net_output[1]['attn'] * (1 - net_output[1]['copy_or_generate']) enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat( 1, net_output[1]['attn'].size(1), 1) final = generate.scatter_add(2, enc_seq_ids, copy) final = torch.log(final + 1e-15) return final else: generate = utils.log_softmax( logits, dim=-1, onnx_trace=self.onnx_trace) * net_output[1]['copy_or_generate'] copy = net_output[1]['attn'] * (1 - net_output[1]['copy_or_generate']) enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat( 1, net_output[1]['attn'].size(1), 1) final = generate.scatter_add(2, enc_seq_ids, copy) return final def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): dim = tensor.size(0) if not hasattr( self, '_future_mask' ) or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu( 
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms layer_norm_map = { '0': 'self_attn_layer_norm', '1': 'encoder_attn_layer_norm', '2': 'final_layer_norm' } for old, new in layer_norm_map.items(): for m in ('weight', 'bias'): k = '{}.layers.{}.layer_norms.{}.{}'.format( name, i, old, m) if k in state_dict: state_dict['{}.layers.{}.{}.{}'.format( name, i, new, m)] = state_dict[k] del state_dict[k] if utils.item( state_dict.get('{}.version'.format(name), torch.Tensor( [1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict['{}.version'.format(name)] = torch.Tensor([1]) return state_dict
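# Illustrative sketch (not part of transformer_with_copyDecoder itself): how a
# pointer-generator style output mixes the vocabulary distribution with copy
# attention over source tokens in probability space, roughly mirroring the
# log-prob branch of get_normalized_probs(). The tensor names and sizes below
# (logits, copy_attn, p_gen, src_tokens) are assumptions for the demo only.
def _demo_copy_generate_mix():
    import torch
    import torch.nn.functional as F

    bsz, tgt_len, src_len, vocab = 2, 3, 4, 10
    logits = torch.randn(bsz, tgt_len, vocab)                 # decoder output logits
    copy_attn = torch.rand(bsz, tgt_len, src_len)
    copy_attn = copy_attn / copy_attn.sum(-1, keepdim=True)   # attention over source tokens
    p_gen = torch.rand(bsz, tgt_len, 1)                       # probability of generating
    src_tokens = torch.randint(0, vocab, (bsz, src_len))

    generate = F.softmax(logits, dim=-1) * p_gen              # generation share
    copy = copy_attn * (1 - p_gen)                            # copying share
    # scatter the copy mass onto the vocabulary ids of the source tokens
    index = src_tokens.unsqueeze(1).repeat(1, tgt_len, 1)
    final = generate.scatter_add(2, index, copy)
    # each row is a proper distribution (up to numerical error)
    assert torch.allclose(final.sum(-1), torch.ones(bsz, tgt_len), atol=1e-5)
    return torch.log(final + 1e-15)                           # log-probabilities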
class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.args = args self.dropout = args.dropout self.bgt_setting = self.args.bgt_setting embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(args) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.hidden2mean = nn.Linear(embed_dim, self.args.latent_size, bias=False) if self.bgt_setting == "bgt": self.hidden2logv = nn.Linear(embed_dim, self.args.latent_size, bias=False) self.latent2hidden = nn.Linear(self.args.latent_size, embed_dim, bias=False) def forward(self, src_tokens, src_lengths, generate=False): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) # if not encoder_padding_mask.any(): # encoder_padding_mask = None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if self.layer_norm: x = self.layer_norm(x) #sample z z = None if self.bgt_setting == "bgt" and not generate: z = torch.randn([x.size()[1], self.args.latent_size]) sent_emb, mean, logv = self.get_sentence_embs(x, encoder_padding_mask, z) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T, 'sent_emb': sent_emb, 'mean': mean, 'logv': logv, 'z': z, } def get_sentence_embs(self, encoder_out, encoder_padding_mask, z=None): if not self.args.cpu: mean_pool = torch.where( encoder_padding_mask.unsqueeze(2).cuda(), torch.Tensor([float(0)]).cuda(), encoder_out.transpose(1, 0).float()).type_as(encoder_out) else: mean_pool = torch.where(encoder_padding_mask.unsqueeze(2), torch.Tensor([float(0)]), encoder_out.transpose( 1, 0).float()).type_as(encoder_out) den = encoder_padding_mask.size()[1] - encoder_padding_mask.sum(dim=1) mean_pool = mean_pool.sum(dim=1) / den.float().unsqueeze(1) mean = self.hidden2mean(mean_pool) logv = None if self.bgt_setting == "bgt": logv = self.hidden2logv(mean_pool) if z is not None: std = torch.exp(0.5 * logv) if not self.args.cpu: z = z.cuda() z = z * std + mean sent_emb = self.latent2hidden(z) else: 
sent_emb = self.latent2hidden(mean) else: sent_emb = mean return sent_emb, mean, logv def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) if encoder_out['sent_emb'] is not None: encoder_out['sent_emb'] = \ encoder_out['sent_emb'].index_select(0, new_order) if encoder_out['mean'] is not None: encoder_out['mean'] = \ encoder_out['mean'].index_select(0, new_order) if encoder_out['logv'] is not None: encoder_out['logv'] = \ encoder_out['logv'].index_select(0, new_order) if encoder_out['z'] is not None: encoder_out['z'] = \ encoder_out['z'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
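# Illustrative sketch (not part of the encoder above): masked mean pooling of
# T x B x C encoder states followed by the usual VAE reparameterisation
# z = mean + std * eps, which is roughly what get_sentence_embs() does for the
# "bgt" setting (the latent2hidden projection is omitted here). The latent
# size and linear layers are assumptions for the demo.
def _demo_masked_mean_and_reparam():
    import torch

    src_len, bsz, dim, latent = 5, 2, 8, 4
    enc_out = torch.randn(src_len, bsz, dim)                  # T x B x C
    padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
    padding_mask[:, -1] = True                                # last position is padding

    # zero out padded positions, then average over the real tokens only
    states = enc_out.transpose(0, 1)                          # B x T x C
    states = states.masked_fill(padding_mask.unsqueeze(-1), 0.0)
    lengths = (~padding_mask).sum(dim=1, keepdim=True).float()
    mean_pool = states.sum(dim=1) / lengths                   # B x C

    hidden2mean = torch.nn.Linear(dim, latent, bias=False)
    hidden2logv = torch.nn.Linear(dim, latent, bias=False)
    mean, logv = hidden2mean(mean_pool), hidden2logv(mean_pool)
    std = torch.exp(0.5 * logv)
    z = torch.randn(bsz, latent) * std + mean                 # reparameterised sample
    return z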
class JointAttentionDecoder(FairseqIncrementalDecoder): """ JointAttention decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`ProtectedTransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding left_pad (bool, optional): whether the input is left-padded. Default: ``False`` """ def __init__(self, args, dictionary, embed_tokens, left_pad=False, final_norm=True): super().__init__(dictionary) self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed self.kernel_size_list = args.kernel_size_list input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim output_embed_dim = args.decoder_output_dim padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.project_in_dim = Linear( input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None self.embed_positions = PositionalEmbedding( args.max_target_positions, embed_dim, padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.embed_language = LanguageEmbedding( embed_dim) if args.language_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ ProtectedTransformerDecoderLayer(args, no_encoder_attn=True) for _ in range(args.decoder_layers) ]) self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \ if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None if not self.share_input_output_embed: self.embed_out = nn.Parameter( torch.Tensor(len(dictionary), output_embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim**-0.5) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.decoder_normalize_before and final_norm if self.normalize: self.layer_norm = LayerNorm(embed_dim) def forward(self, prev_output_tokens, encoder_out, incremental_state=None): """ Args: input (dict): with prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing encoder_out (Tensor, optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` Returns: tuple: - the last decoder layer's output of shape `(batch, tgt_len, vocab)` - the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)` """ tgt_len = prev_output_tokens.size(1) # embed positions positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if self.embed_positions is not None else None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions # language embedding if self.embed_language is not None: lang_emb = self.embed_scale * self.embed_language.view(1, 1, -1) x += lang_emb x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) attn = None inner_states = [x] source = encoder_out['encoder_out'] process_source = incremental_state is None or len( incremental_state) == 0 # extended padding mask 
source_padding_mask = encoder_out['encoder_padding_mask'] if source_padding_mask is not None: target_padding_mask = source_padding_mask.new_zeros( (source_padding_mask.size(0), tgt_len)) self_attn_padding_mask = torch.cat( (source_padding_mask, target_padding_mask), dim=1) else: self_attn_padding_mask = None # transformer layers for i, layer in enumerate(self.layers): if self.kernel_size_list is not None: target_mask = self.local_mask(x, self.kernel_size_list[i], causal=True, tgt_len=tgt_len) elif incremental_state is None: target_mask = self.buffered_future_mask(x) else: target_mask = None if target_mask is not None: zero_mask = target_mask.new_zeros( (target_mask.size(0), source.size(0))) self_attn_mask = torch.cat((zero_mask, target_mask), dim=1) else: self_attn_mask = None state = incremental_state if process_source: if state is None: state = {} if self.kernel_size_list is not None: source_mask = self.local_mask(source, self.kernel_size_list[i], causal=False) else: source_mask = None source, attn = layer( source, None, None, state, self_attn_mask=source_mask, self_attn_padding_mask=source_padding_mask) inner_states.append(source) x, attn = layer(x, None, None, state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask) inner_states.append(x) if self.normalize: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) # project back to size of vocabulary if self.share_input_output_embed: x = F.linear(x, self.embed_tokens.weight) else: x = F.linear(x, self.embed_out) pred = x info = {'attn': attn, 'inner_states': inner_states} return pred, info def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): """Cached future mask.""" dim = tensor.size(0) #pylint: disable=access-member-before-definition, attribute-defined-outside-init if not hasattr( self, '_future_mask' ) or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) return self._future_mask[:dim, :dim] def local_mask(self, tensor, kernel_size, causal, tgt_len=None): """Locality constraint mask.""" rows = tensor.size(0) cols = tensor.size(0) if tgt_len is None else tgt_len if causal: if rows == 1: mask = utils.fill_with_neg_inf(tensor.new(1, cols)) mask[0, -kernel_size:] = 0 return mask else: diag_u, diag_l = 1, kernel_size else: diag_u, diag_l = ((kernel_size + 1) // 2, (kernel_size + 1) // 2) if kernel_size % 2 == 1 \ else (kernel_size // 2, kernel_size // 2 + 1) mask1 = torch.triu(utils.fill_with_neg_inf(tensor.new(rows, cols)), diag_u) mask2 = torch.tril(utils.fill_with_neg_inf(tensor.new(rows, cols)), -diag_l) return mask1 + mask2
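# Illustrative sketch (not part of JointAttentionDecoder itself): building the
# joint self-attention mask used when source and target states share the key
# side. Target queries get a causal mask over target keys and an all-zero
# (fully visible) block over source keys. The lengths are made up for the demo.
def _demo_joint_attention_mask():
    import torch

    src_len, tgt_len = 4, 3
    neg_inf = float('-inf')
    # causal part over target positions: -inf strictly above the diagonal
    target_mask = torch.triu(torch.full((tgt_len, tgt_len), neg_inf), diagonal=1)
    # target queries may always attend to every source position
    zero_mask = torch.zeros(tgt_len, src_len)
    self_attn_mask = torch.cat((zero_mask, target_mask), dim=1)  # tgt_len x (src_len + tgt_len)
    return self_attn_mask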
class TransformerDecoderPerm(FairseqIncrementalDecoder): """Transformer decoder.""" def __init__(self, args, dictionary, embed_tokens, left_pad=False): super().__init__(dictionary) if not isinstance(args.shorten_decoder_perm, bool): args.shorten_decoder_perm = eval(args.shorten_decoder_perm) self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed self.embed_dim = embed_dim = embed_tokens.embedding_dim self.padding_idx = padding_idx = embed_tokens.padding_idx self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( 1024, embed_dim, padding_idx, learned=args.decoder_learned_pos, ) self.layers = nn.ModuleList([]) self.layers.extend([ TransformerDecoderPermLayer(args) for i in range(args.decoder_perm_layers) ]) self.sentence_transformer_arch = args.sentence_transformer_arch self.predict_arch = args.predict_arch self.pointer_net_attn_type = args.pointer_net_attn_type if not self.share_input_output_embed and self.predict_arch == 'seq2seq': self.embed_out = nn.Parameter( torch.Tensor(len(dictionary), embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=embed_dim**-0.5) if self.predict_arch == 'pointer_net': if self.pointer_net_attn_type == 'perceptron': self.pointer_encoder_embed_weight = nn.Parameter( torch.Tensor(embed_dim, embed_dim)) self.pointer_decoder_embed_weight = nn.Parameter( torch.Tensor(embed_dim, embed_dim)) self.mapping_vector = nn.Parameter(torch.Tensor(1, embed_dim)) nn.init.normal_(self.pointer_encoder_embed_weight, mean=0, std=embed_dim**-0.5) nn.init.normal_(self.pointer_decoder_embed_weight, mean=0, std=embed_dim**-0.5) nn.init.normal_(self.mapping_vector, mean=0, std=embed_dim**-0.5) elif self.pointer_net_attn_type == 'general': self.pointer_attn_weight = nn.Parameter( torch.Tensor(args.decoder_embed_dim, args.encoder_embed_dim)) nn.init.normal_(self.pointer_attn_weight, mean=0, std=embed_dim**-0.5) elif self.pointer_net_attn_type == 'dot': pass else: raise RuntimeError( "pointer-net-attn-type doesn't support {} yet !".format( self.pointer_net_attn_type)) def buffered_future_mask(self, tensor): dim = tensor.size(0) if (not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device or self._future_mask.size(0) < dim): self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) return self._future_mask[:dim, :dim] def forward(self, prev_output_tokens, encoder_out, incremental_state=None): # embed positions positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) x += positions # add the sent embedding to x prev_output_tokens_temp = prev_output_tokens.masked_fill( prev_output_tokens == self.padding_idx, 0) sents_embedding = torch.stack([ encoder_out['encoder_out'][i, prev_output_tokens_temp[:, 1:][i]] for i in range(prev_output_tokens_temp.shape[0]) ]) sents_embedding[prev_output_tokens[:, 1:] == self.padding_idx] = self.embed_tokens( torch.LongTensor([self.padding_idx]).to(x.device)) x[:, 1:] += sents_embedding x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) encoder_out_embedding = encoder_out['encoder_out'].transpose( 0, 1) if self.sentence_transformer_arch == 'bert' else encoder_out[ 'encoder_out'] # decoder layers 
self_attn_mask = self.buffered_future_mask(x) for layer in self.layers: x, attn = layer( x, encoder_out_embedding, encoder_out['encoder_padding_mask'], incremental_state, self_attn_mask, ) # T x B x C -> B x T x C x = x.transpose(0, 1) # project back to size of vocabulary if self.predict_arch == 'seq2seq': if self.share_input_output_embed: out = F.linear(x, self.embed_tokens.weight) else: out = F.linear(x, self.embed_out) elif self.predict_arch == 'pointer_net': bsz = prev_output_tokens.shape[0] encoder_embedding_querry = torch.cat([ encoder_out['encoder_out'], self.embed_tokens( torch.LongTensor([self.dictionary.eos()]).to( x.device)).expand([bsz, 1, self.embed_dim]) ], dim=1) if self.pointer_net_attn_type == 'perceptron': temp_embedding = F.linear( encoder_embedding_querry, self.pointer_encoder_embed_weight ).unsqueeze(dim=1) + F.linear( x, self.pointer_decoder_embed_weight).unsqueeze(dim=2) temp_embedding = F.tanh(temp_embedding) out = F.linear(temp_embedding, self.mapping_vector).squeeze(dim=-1) elif self.pointer_net_attn_type == 'general': out = x.matmul(self.pointer_attn_weight).bmm( encoder_embedding_querry.transpose(-1, -2)) elif self.pointer_net_attn_type == 'dot': out = x.bmm(encoder_embedding_querry.transpose(-1, -2)) return out def max_positions(self): """Maximum output length supported by the decoder.""" return self.embed_positions.max_positions() def upgrade_state_dict(self, state_dict): if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): if 'decoder_perm.embed_positions.weights' in state_dict: del state_dict['decoder_perm.embed_positions.weights'] # if 'decoder_perm.embed_positions._float_tensor' in state_dict: # del state_dict['decoder_perm.embed_positions._float_tensor'] state_dict[ 'decoder_perm.embed_positions._float_tensor'] = torch.FloatTensor( 1) ''' in_proj_weight -> q_proj.weight, k_proj.weight, v_proj.weight in_proj_bias -> q_proj.bias, k_proj.bias, v_proj.bias ''' def transform_params(idx, suffix): in_proj_ = state_dict[ 'decoder_perm.layers.{}.self_attn.in_proj_{}'.format( idx, suffix)] del state_dict[ 'decoder_perm.layers.{}.self_attn.in_proj_{}'.format( idx, suffix)] state_dict['decoder_perm.layers.{}.self_attn.q_proj.{}'.format(idx, suffix)], state_dict['decoder_perm.layers.{}.self_attn.k_proj.{}'.format(idx, suffix)],\ state_dict['decoder_perm.layers.{}.self_attn.v_proj.{}'.format(idx, suffix)] = in_proj_.chunk(3, dim=0) if 'decoder_perm.layers.0.self_attn.in_proj_weight' in state_dict: for idx in range(len(self.layers)): transform_params(idx, 'weight') if 'decoder_perm.layers.0.self_attn.in_proj_bias' in state_dict: for idx in range(len(self.layers)): transform_params(idx, 'bias') return state_dict
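# Illustrative sketch (not part of TransformerDecoderPerm itself): the three
# scoring functions a pointer network can use to point from decoder states to
# encoder states ("dot", "general", "perceptron"/additive). Dimensions and
# parameter names are assumptions for the demo.
def _demo_pointer_scores():
    import torch
    import torch.nn.functional as F

    bsz, tgt_len, src_len, dim = 2, 3, 5, 8
    dec = torch.randn(bsz, tgt_len, dim)                      # decoder states
    enc = torch.randn(bsz, src_len, dim)                      # encoder states

    # dot: plain inner product between decoder and encoder states
    dot = dec.bmm(enc.transpose(-1, -2))                      # B x tgt x src

    # general: bilinear scoring through a learned matrix W
    W = torch.randn(dim, dim)
    general = dec.matmul(W).bmm(enc.transpose(-1, -2))        # B x tgt x src

    # perceptron (additive): v . tanh(W_e * e + W_d * d)
    W_e, W_d, v = torch.randn(dim, dim), torch.randn(dim, dim), torch.randn(1, dim)
    scores = torch.tanh(F.linear(enc, W_e).unsqueeze(1) + F.linear(dec, W_d).unsqueeze(2))
    perceptron = F.linear(scores, v).squeeze(-1)              # B x tgt x src
    return dot, general, perceptron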
class TransformerEncoder(nn.Module): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, vocab, embed_tokens): super().__init__() self.vocab = vocab self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(args, i) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None def forward(self, src_tokens, data_holder=None, mask=None, encoder_mode="soft", encoder_temperature=-1, need_weights=False, **unused): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions embedding = self.embed_tokens(src_tokens) if data_holder is not None: if data_holder.permute_embed is not None: bsz, _, dim = embedding.shape for i in range(bsz): perm = torch.randperm(dim) embedding[i, data_holder.permute_embed] = embedding[ i, data_holder.permute_embed, perm] if data_holder.keep_grads: data_holder.embedding = embedding data_holder.embedding.retain_grad() x = self.embed_scale * embedding x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if not encoder_padding_mask.any(): encoder_padding_mask = None # encoder layers attn_data_list = [] for layer in self.layers: x, attn_data = layer(x, data_holder=data_holder, encoder_padding_mask=encoder_padding_mask, attn_mask=mask, encoder_mode=encoder_mode, encoder_temperature=encoder_temperature, need_weights=need_weights) attn_data_list.append(attn_data) if self.layer_norm: x = self.layer_norm(x) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T 'encoder_attn_data_list': attn_data_list } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions())
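# Illustrative sketch (not part of the encoder above): randomly permuting the
# embedding channels at a chosen set of positions, as an input-perturbation
# probe similar in spirit to the data_holder.permute_embed hook. The position
# list and sizes are assumptions for the demo.
def _demo_permute_embedding_channels():
    import torch

    bsz, src_len, dim = 2, 6, 8
    embedding = torch.randn(bsz, src_len, dim)
    positions = torch.tensor([1, 3])                          # positions to perturb

    for i in range(bsz):
        perm = torch.randperm(dim)
        # shuffle the feature channels of the chosen positions for sample i
        embedding[i, positions] = embedding[i, positions][:, perm]
    return embedding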
class TransformerYmaskEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(args) for i in range(args.encoder_layers) ]) if args.encoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None self.embed_lengths = nn.Embedding(args.max_target_positions, embed_dim) nn.init.normal_(self.embed_lengths.weight, mean=0, std=0.02) def forward(self, src_tokens, src_lengths, **unused): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) # add length prediction part len_tokens = self.embed_lengths( src_tokens.new(src_tokens.size(0), 1).fill_(0)) x = torch.cat([len_tokens, x], dim=1) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) # to keep consistent with x encoder_padding_mask = torch.cat([ encoder_padding_mask.new(src_tokens.size(0), 1).fill_(0), encoder_padding_mask ], dim=1) if not encoder_padding_mask.any(): encoder_padding_mask = None # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if self.layer_norm: x = self.layer_norm(x) predicted_lengths_logits = torch.matmul( x[0, :, :], self.embed_lengths.weight.transpose(0, 1)).float() predicted_lengths_logits[:, 0] += float( '-inf') # Cannot predict the len_token predicted_lengths = F.log_softmax(predicted_lengths_logits, dim=-1) x = x[1:, :, :] if encoder_padding_mask is not None: encoder_padding_mask = encoder_padding_mask[:, 1:] return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T 'predicted_lengths': predicted_lengths, # B x L } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. 
Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) if encoder_out['predicted_lengths'] is not None: encoder_out['predicted_lengths'] = \ encoder_out['predicted_lengths'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named( state_dict, "{}.layers.{}".format(name, i)) version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
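# Illustrative sketch (not part of TransformerYmaskEncoder itself): prepending
# a learned "length token" to the source and reading a distribution over
# target lengths from its final encoder state, as the encoder above does. The
# maximum length and dimensions are assumptions for the demo, and the encoder
# layers themselves are skipped.
def _demo_length_prediction():
    import torch
    import torch.nn.functional as F

    bsz, src_len, dim, max_len = 2, 5, 8, 50
    embed_lengths = torch.nn.Embedding(max_len, dim)
    x = torch.randn(bsz, src_len, dim)                        # token embeddings, B x T x C

    len_token = embed_lengths(torch.zeros(bsz, 1, dtype=torch.long))  # index 0 = length query
    x = torch.cat([len_token, x], dim=1)                      # B x (T+1) x C

    # ... encoder layers would run here; use x directly for the demo ...
    x = x.transpose(0, 1)                                     # (T+1) x B x C
    length_logits = x[0] @ embed_lengths.weight.t()           # B x max_len
    length_logits[:, 0] = float('-inf')                       # cannot predict the query slot itself
    predicted_lengths = F.log_softmax(length_logits, dim=-1)
    return predicted_lengths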
class TransformerDecoder(FairseqIncrementalDecoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout self.decoder_layerdrop = args.decoder_layerdrop self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim self.output_embed_dim = args.decoder_output_dim self.padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( embed_dim) self.project_in_dim = Linear( input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None self.embed_positions = PositionalEmbedding( args.max_target_positions, embed_dim, self.padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.cross_self_attention = getattr(args, 'cross_self_attention', False) self.layer_wise_attention = getattr(args, 'layer_wise_attention', False) self.layers = nn.ModuleList([]) self.layers.extend([ TransformerDecoderLayer(args, no_encoder_attn, layer_id=i) for i in range(args.decoder_layers) ]) self.adaptive_softmax = None self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \ if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, options.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif not self.share_input_output_embed: self.embed_out = nn.Parameter( torch.Tensor(len(dictionary), self.output_embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5) if args.decoder_normalize_before and not getattr( args, 'no_decoder_final_norm', False): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None if getattr(args, 'layernorm_embedding', False): self.layernorm_embedding = LayerNorm(embed_dim) else: self.layernorm_embedding = None def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, features_only=False, **extra_args): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing encoder_out (optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). 
Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, **extra_args) if not features_only: x = self.output_layer(x) return x, extra def extract_features( self, prev_output_tokens, encoder_out=None, incremental_state=None, full_context_alignment=False, alignment_layer=None, alignment_heads=None, **unused, ): """ Similar to *forward* but only return features. Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al., EMNLP 2019). Args: full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). alignment_layer (int, optional): return mean alignment over heads at this layer (default: last layer). alignment_heads (int, optional): only average alignment over this many heads (default: all heads). Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ if alignment_layer is None: alignment_layer = len(self.layers) - 1 # embed positions positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if self.embed_positions is not None else None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions if self.layernorm_embedding: x = self.layernorm_embedding(x) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) self_attn_padding_mask = None if self.cross_self_attention or prev_output_tokens.eq( self.padding_idx).any(): self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) # decoder layers attn = None inner_states = [x] for idx, layer in enumerate(self.layers): encoder_state = None if encoder_out is not None: if self.layer_wise_attention: encoder_state = encoder_out.encoder_states[idx] else: encoder_state = encoder_out.encoder_out if incremental_state is None and not full_context_alignment: self_attn_mask = self.buffered_future_mask(x) else: self_attn_mask = None # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) if not self.training or (dropout_probability > self.decoder_layerdrop): x, layer_attn = layer( x, encoder_state, encoder_out.encoder_padding_mask if encoder_out is not None else None, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_attn=(idx == alignment_layer), need_head_weights=(idx == alignment_layer), ) inner_states.append(x) if layer_attn is not None and idx == alignment_layer: attn = layer_attn.float() if attn is not None: if alignment_heads is not None: attn = attn[:alignment_heads] # average probabilities over heads attn = attn.mean(dim=0) if self.layer_norm: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) return x, {'attn': attn, 'inner_states': inner_states} def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary if self.share_input_output_embed: return 
F.linear(features, self.embed_tokens.weight) else: return F.linear(features, self.embed_out) else: return features def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): dim = tensor.size(0) if (not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device or self._future_mask.size(0) < dim): self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms layer_norm_map = { '0': 'self_attn_layer_norm', '1': 'encoder_attn_layer_norm', '2': 'final_layer_norm' } for old, new in layer_norm_map.items(): for m in ('weight', 'bias'): k = '{}.layers.{}.layer_norms.{}.{}'.format( name, i, old, m) if k in state_dict: state_dict['{}.layers.{}.{}.{}'.format( name, i, new, m)] = state_dict[k] del state_dict[k] version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
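# Illustrative sketch (not part of the decoder above): LayerDrop
# (https://arxiv.org/abs/1909.11556) skips whole layers with probability p
# during training and keeps them all at inference time, matching the condition
# used in extract_features(). The layer stack here is a stand-in for the real
# decoder layers.
def _demo_layerdrop(p=0.2, training=True):
    import random
    import torch

    layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(6)])
    x = torch.randn(4, 8)
    for layer in layers:
        # at inference, or when the random draw survives the drop rate, apply the layer
        if not training or random.uniform(0, 1) > p:
            x = layer(x)
    return x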
class GraphTransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): encoding dictionary embed_tokens (torch.nn.Embedding): input embedding """ def __init__(self, args, dictionary, embed_tokens, embed_edges): super().__init__(dictionary) self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.embed_dim = embed_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens self.embed_edges = embed_edges self.embed_scale = math.sqrt(embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.l1_gate = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.SELU(), nn.Dropout(self.dropout)) self.l2_gate = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.SELU(), nn.Dropout(self.dropout)) # self.e1_gate = nn.Linear(embed_dim*3, embed_dim) # self.e2_gate = nn.Linear(embed_dim * 3, embed_dim) self.e1_gate = nn.Sequential(nn.Linear(embed_dim * 3, embed_dim), nn.SELU(), nn.Dropout(self.dropout)) self.e2_gate = nn.Sequential(nn.Linear(embed_dim * 3, embed_dim), nn.SELU(), nn.Dropout(self.dropout)) self.layers = nn.ModuleList([]) self.layers.extend([ TransformerEncoderLayer(args) for i in range(args.encoder_layers) ]) self.graph_layers = nn.ModuleList([]) self.graph_layers.extend([ GraphTransformerEncoderLayer(args) for i in range(args.graph_layers) ]) self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim, batch_first=False, num_layers=1, bidirectional=True) self.register_buffer('version', torch.Tensor([2])) self.normalize = args.encoder_normalize_before if self.normalize: self.layer_norm = LayerNorm(embed_dim) self.graph_layer_norm = LayerNorm(embed_dim) def forward(self, src_tokens, src_lengths, enc_edge_ids, enc_edge_links1, enc_edge_links2, graph_mask, graph_mask_rev): """ Args: src_tokens (LongTensor): tokens in the source language of shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` Returns: dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ # embed tokens and positions x = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x += self.embed_positions(src_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) if not encoder_padding_mask.any(): encoder_padding_mask = None enc_edge_padding_mask = enc_edge_ids.eq(self.padding_idx) # embed edges enc_edge_emb = self.embed_scale * self.embed_edges(enc_edge_ids) enc_edge_emb = F.dropout(enc_edge_emb, p=self.dropout, training=self.training) # B x L x C -> L x B x C enc_edge_emb = enc_edge_emb.transpose(0, 1) # B x L -> L x B x C idx_pairs1 = enc_edge_links1.unsqueeze(-1).repeat(1, 1, self.embed_dim) idx_pairs2 = enc_edge_links2.unsqueeze(-1).repeat(1, 1, self.embed_dim) idx_pairs1 = idx_pairs1.transpose(0, 1) idx_pairs2 = idx_pairs2.transpose(0, 1) # encoder layers for layer in self.layers: x = layer(x, encoder_padding_mask) if self.normalize: x 
= self.layer_norm(x) # graph layers enc_states = [] enc_states.append(x) for layer in self.graph_layers: x = layer(x, self.l1_gate, self.l2_gate, self.e1_gate, self.e2_gate, idx_pairs1, idx_pairs2, encoder_padding_mask, enc_edge_padding_mask, graph_mask, graph_mask_rev, enc_edge_emb) enc_states.append(x) enc_states = torch.stack(enc_states, dim=3) enc_states = torch.transpose(torch.transpose(enc_states, 2, 3), 0, 2) bsz = enc_states.size(1) seq_len = enc_states.size(2) enc_states = enc_states.reshape( len(self.graph_layers) + 1, bsz * seq_len, self.embed_dim) outputs, state = self.rnn(enc_states) x = state[0][::2] + state[1][::2] x = x.reshape(bsz, -1, self.embed_dim) x = torch.transpose(x, 0, 1) if self.normalize: x = self.graph_layer_norm(x) return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T } def reorder_encoder_out(self, encoder_out, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ if encoder_out['encoder_out'] is not None: encoder_out['encoder_out'] = \ encoder_out['encoder_out'].index_select(1, new_order) if encoder_out['encoder_padding_mask'] is not None: encoder_out['encoder_padding_mask'] = \ encoder_out['encoder_padding_mask'].index_select(0, new_order) return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions()) def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format( name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms self.layers[i].upgrade_state_dict_named(state_dict, f"{name}.layers.{i}") version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
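# Illustrative sketch (not part of GraphTransformerEncoder itself): a
# simplified variant of the layer-state fusion used above, where the state
# after every graph layer is stacked and the per-layer trajectory of each
# position is summarised by a small bidirectional LSTM. Layer count and sizes
# are assumptions for the demo.
def _demo_fuse_layer_states_with_bilstm():
    import torch

    num_states, src_len, bsz, dim = 3, 4, 2, 8                # e.g. input + 2 graph layers
    states = [torch.randn(src_len, bsz, dim) for _ in range(num_states)]

    # stack to (num_states, bsz * src_len, dim): one short "sequence" per position
    stacked = torch.stack(states, dim=0)                      # L x T x B x C
    stacked = stacked.permute(0, 2, 1, 3).reshape(num_states, bsz * src_len, dim)

    rnn = torch.nn.LSTM(input_size=dim, hidden_size=dim, bidirectional=True)
    _, (h, _) = rnn(stacked)                                  # h: 2 x (B*T) x C
    fused = h[0] + h[1]                                       # sum the two directions
    fused = fused.reshape(bsz, src_len, dim).transpose(0, 1)  # back to T x B x C
    return fused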
class TransformerDecoder(FairseqIncrementalDecoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = args.decoder_embed_dim self.output_embed_dim = args.decoder_output_dim padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None self.embed_positions = PositionalEmbedding( args.max_target_positions, embed_dim, padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerDecoderLayer(args, no_encoder_attn) for _ in range(args.decoder_layers) ]) self.adaptive_softmax = None self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \ if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None if args.adaptive_softmax_cutoff is not None: self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, options.eval_str_list(args.adaptive_softmax_cutoff, type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, tie_proj=args.tie_adaptive_proj, ) elif not self.share_input_output_embed: self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5) if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None #---------------------------- self.save_attn = args.save_attn self.save_attn_path = args.save_attn_path def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing encoder_out (Tensor, optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state) x = self.output_layer(x) return x, extra def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): """ Similar to *forward* but only return features. 
Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ # embed positions positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if self.embed_positions is not None else None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) attn = None inner_states = [x] # decoder layers for layer in self.layers: x, attn = layer( x, encoder_out['encoder_out'] if encoder_out is not None else None, encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state, self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, ) inner_states.append(x) if self.layer_norm: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) #--------------------------------------------------- if self.save_attn == True: save_attn(attn,self.save_attn_path,self.dictionary.string(prev_output_tokens), encoder_out['src_tokens']) return x, {'attn': attn, 'inner_states': inner_states} def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary if self.share_input_output_embed: return F.linear(features, self.embed_tokens.weight) else: return F.linear(features, self.embed_out) else: return features def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): dim = tensor.size(0) if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) return self._future_mask[:dim, :dim] def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) for i in range(len(self.layers)): # update layer norms layer_norm_map = { '0': 'self_attn_layer_norm', '1': 'encoder_attn_layer_norm', '2': 'final_layer_norm' } for old, new in layer_norm_map.items(): for m in ('weight', 'bias'): k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m) if k in state_dict: state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k] del state_dict[k] version_key = '{}.version'.format(name) if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: # earlier checkpoints did not normalize after the stack of layers self.layer_norm = None self.normalize = False state_dict[version_key] = torch.Tensor([1]) return state_dict
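# Illustrative sketch only: this is NOT the actual save_attn helper called in
# extract_features() above (its definition lives elsewhere in the codebase).
# It shows one plausible way to persist a cross-attention matrix together with
# the decoded/source strings for later inspection; the file layout and
# argument names are assumptions for the demo.
def _demo_save_attention(attn, path, tgt_string, src_string):
    import os
    import torch

    os.makedirs(path, exist_ok=True)
    record = {
        'attn': attn.detach().cpu(),   # B x tgt_len x src_len attention weights
        'target': tgt_string,
        'source': src_string,
    }
    torch.save(record, os.path.join(path, 'attn_{}.pt'.format(len(os.listdir(path)))))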
class SentTransformerDecoder(nn.Module): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). final_norm (bool, optional): apply layer norm to the output of the final decoder layer (default: True). """ def __init__(self, args, no_encoder_attn=False, final_norm=True): super(SentTransformerDecoder, self).__init__() self.dropout = args.dropout embed_dim = args.decoder_embed_dim self.max_target_positions = args.max_target_positions self.embed_positions = PositionalEmbedding( args.max_target_positions, embed_dim, padding_idx=0, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None self.layers = nn.ModuleList([]) self.layers.extend([ TransformerDecoderLayer(args=args, no_encoder_attn=no_encoder_attn) for i in range(args.decoder_layers) ]) self.normalize = args.decoder_normalize_before and final_norm if self.normalize: self.layer_norm = LayerNorm(embed_dim) def forward(self, prev_output_tokens, prev_output_rep, encoder_out=None, incremental_state=None, **unused): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing prev_output_rep encoder_out (Tensor, optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` Returns: tuple: - the decoder's output of shape `(batch, tgt_len, dim)` - a dictionary with any model-specific outputs """ x, extra = self.extract_features(prev_output_tokens, prev_output_rep, encoder_out, incremental_state) # x = self.output_layer(x) return x, extra def extract_features(self, prev_output_tokens, prev_output_rep, encoder_out=None, incremental_state=None): """ Similar to *forward* but only return features. 
Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ # embed positions # (batch, tgt_len) positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state, ) if self.embed_positions is not None else None if incremental_state is not None: # incremental decoding: keep only the last step; earlier steps are already cached prev_output_tokens = prev_output_tokens[:, -1:] prev_output_rep = prev_output_rep[:, -1, :] if positions is not None: positions = positions[:, -1:] # embed tokens and positions # x = self.embed_scale * self.embed_tokens(prev_output_tokens) x = prev_output_rep # if self.project_in_dim is not None: # x = self.project_in_dim(x) if positions is not None: # print('len x %s |len positions %s'%(x.size(),positions.size())) x += positions x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) attn = None inner_states = [x] # decoder layers for layer in self.layers: x, attn = layer( x, encoder_out['encoder_out'] if encoder_out is not None else None, encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state, self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, ) inner_states.append(x) if self.normalize: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) return x, {'attn': attn, 'inner_states': inner_states} def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions()) def buffered_future_mask(self, tensor): dim = tensor.size(0) if not hasattr( self, '_future_mask' ) or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) return self._future_mask[:dim, :dim]
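# Illustrative sketch (not part of SentTransformerDecoder itself): this
# decoder consumes pre-computed sentence representations rather than token
# embeddings, so the decoder input is simply those vectors plus positional
# embeddings. Shapes and the learned positional table are assumptions for the
# demo.
def _demo_sentence_level_decoder_input():
    import torch
    import torch.nn.functional as F

    bsz, tgt_len, dim, dropout = 2, 4, 8, 0.1
    prev_output_rep = torch.randn(bsz, tgt_len, dim)          # one vector per previous sentence
    positions = torch.nn.Embedding(32, dim)(torch.arange(tgt_len)).unsqueeze(0)

    x = prev_output_rep + positions                           # no token embedding lookup
    x = F.dropout(x, p=dropout, training=True)
    x = x.transpose(0, 1)                                     # B x T x C -> T x B x C
    return x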