def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
    super().__init__(args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn)
    self.dictionary = dictionary
    self.bos = dictionary.bos()
    self.unk = dictionary.unk()
    self.eos = dictionary.eos()
    self.sampling_for_deletion = getattr(args, "sampling_for_deletion", False)

    # insertion head: predicts how many placeholders (0-255) to insert in each
    # slot; it scores a pair of adjacent decoder states, hence the doubled dim
    self.embed_mask_ins = Embedding(256, self.output_embed_dim * 2, None)
    # deletion head: binary keep/delete decision per token
    self.embed_word_del = Embedding(2, self.output_embed_dim, None)

    # layers at which each head exits the decoder: del_word, ins_mask, ins_word
    self.early_exit = [int(i) for i in args.early_exit.split(',')]
    assert len(self.early_exit) == 3

    # optional separate decoder stacks for mask-prediction/deletion
    self.layers_msk = None
    if getattr(args, "no_share_maskpredictor", False):
        self.layers_msk = nn.ModuleList([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(self.early_exit[1])
        ])
    self.layers_del = None
    if getattr(args, "no_share_discriminator", False):
        self.layers_del = nn.ModuleList([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(self.early_exit[0])
        ])

    if getattr(args, "share_discriminator_maskpredictor", False):
        assert getattr(args, "no_share_discriminator", False), \
            "must set separate discriminator"
        self.layers_msk = self.layers_del
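# Illustrative sketch (assumption, not taken from this file): the mask-insertion
# head scores pairs of adjacent decoder states, which is why ``embed_mask_ins``
# takes ``output_embed_dim * 2`` inputs. Roughly:
#
#   features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], dim=2)
#   mask_ins_scores = F.linear(features_cat, self.embed_mask_ins.weight)
#
# giving, for every slot between two tokens, a distribution over how many
# placeholders (0-255) to insert there.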
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False,
             left_pad=False, final_norm=True):
    super().__init__(args, dictionary, embed_tokens, no_encoder_attn, left_pad, final_norm)

    self.share_directions = getattr(args, 'decoder_share_directions', False)
    if self.share_directions:
        # no separate backward stack when the two directions share parameters
        self.layers_bw = None
    else:
        # backward (right-to-left) decoder stack
        self.layers_bw = nn.ModuleList([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])

    self.padding_idx = dictionary.pad()
    self.embedding_dim = self.embed_dim = args.decoder_embed_dim

    self.n_hidden_states = args.decoder_layers
    if args.decoder_use_biattention:
        self.n_hidden_states += 1

    # optional block that lets the forward and backward states attend to each other
    self.use_biattention = args.decoder_use_biattention
    self.biblock = BiTransformerDecoderLayer(args) if self.use_biattention else None

    # learned padding vectors for the forward and backward streams
    self.pad_fw = nn.Parameter(torch.randn(1, 1, args.decoder_embed_dim), requires_grad=True)
    self.pad_bw = nn.Parameter(torch.randn(1, 1, args.decoder_embed_dim), requires_grad=True)
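# Sketch (assumption): when ``decoder_share_directions`` is set, the backward
# pass would reuse the forward stack inherited from the parent decoder rather
# than a separate ``layers_bw`` list, e.g. in forward():
#
#   layers_bw = self.layers if self.layers_bw is None else self.layers_bw
#
# while ``pad_fw`` / ``pad_bw`` presumably pad the forward and backward streams
# to a common length before the optional bi-attention block combines them.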
def __init__(self, args, dictionary, embed_tokens, classification_head=None):
    super().__init__(dictionary)
    self.onnx_trace = False
    self.dropout = args.dropout
    self.share_input_output_embed = args.share_decoder_input_output_embed
    self.embed_dim = embed_tokens.embedding_dim
    self.padding_idx = embed_tokens.padding_idx
    self.max_target_positions = args.max_target_positions
    self.self_target = args.self_target
    self.future_target = args.future_target
    self.past_target = args.past_target
    self.char_inputs = args.char_inputs

    self.embed_tokens = embed_tokens
    self.embed_scale = math.sqrt(self.embed_dim)

    self.embed_positions = PositionalEmbedding(
        args.max_target_positions,
        self.embed_dim,
        self.padding_idx,
        learned=args.decoder_learned_pos,
    ) if not args.no_token_positional_embeddings else None

    # two language-model decoder stacks, one per direction (no encoder attention)
    self.forward_layers = nn.ModuleList([
        TransformerDecoderLayer(
            args,
            no_encoder_attn=True,
            add_bias_kv=not args.no_bias_kv,
            add_zero_attn=args.no_bias_kv,
        )
        for _ in range(args.decoder_layers)
    ])
    self.backward_layers = nn.ModuleList([
        TransformerDecoderLayer(
            args,
            no_encoder_attn=True,
            add_bias_kv=not args.no_bias_kv,
            add_zero_attn=args.no_bias_kv,
        )
        for _ in range(args.decoder_layers)
    ])

    # combine both directions for the self (current-token) target, either with a
    # bidirectional attention layer or a linear layer over the concatenated states
    self.full_attn_layer = None
    self.full_linear_layer = None
    if self.self_target:
        if args.linear_final_layer:
            self.full_linear_layer = Linear(
                self.embed_dim * 2, self.embed_dim, args.linear_final_layer_bias)
        else:
            self.full_attn_layer = BidirectionalTransformerDecoderLayer(args)

    self.load_softmax = not getattr(args, 'remove_head', False)
    self.embed_out = None
    self.adaptive_softmax = None
    self.classification_head = classification_head

    if self.load_softmax:
        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                args.decoder_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
            )
        elif not self.share_input_output_embed:
            # separate output embedding when not tying with the input embeddings
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=self.embed_dim ** -0.5)
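# Sketch (assumption, mirroring the usual fairseq output projection): the final
# decoder features would be mapped to vocabulary logits with the adaptive
# softmax, the tied input embedding, or the separate ``embed_out`` matrix:
#
#   if self.adaptive_softmax is None:
#       if self.share_input_output_embed:
#           logits = F.linear(features, self.embed_tokens.weight)
#       elif self.embed_out is not None:
#           logits = F.linear(features, self.embed_out)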