Code example #1
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(args,
                         dictionary,
                         embed_tokens,
                         no_encoder_attn=no_encoder_attn)
        self.dictionary = dictionary
        self.bos = dictionary.bos()
        self.unk = dictionary.unk()
        self.eos = dictionary.eos()
        self.sampling_for_deletion = getattr(args, "sampling_for_deletion",
                                             False)
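        # Output heads for the edit operations: embed_mask_ins scores how many
        # placeholder tokens to insert (0-255 classes), embed_word_del scores
        # keep vs. delete for each token.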
        self.embed_mask_ins = Embedding(256, self.output_embed_dim * 2, None)
        self.embed_word_del = Embedding(2, self.output_embed_dim, None)

        # del_word, ins_mask, ins_word
        self.early_exit = [int(i) for i in args.early_exit.split(',')]
        assert len(self.early_exit) == 3

        # copy layers for mask-predict/deletion
        self.layers_msk = None
        if getattr(args, "no_share_maskpredictor", False):
            self.layers_msk = nn.ModuleList([
                TransformerDecoderLayer(args, no_encoder_attn)
                for _ in range(self.early_exit[1])
            ])
        self.layers_del = None
        if getattr(args, "no_share_discriminator", False):
            self.layers_del = nn.ModuleList([
                TransformerDecoderLayer(args, no_encoder_attn)
                for _ in range(self.early_exit[0])
            ])

        if getattr(args, "share_discriminator_maskpredictor", False):
            assert getattr(args, "no_share_discriminator",
                           False), "must set separate discriminator"
            self.layers_msk = self.layers_del
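
To make the configuration concrete, here is a minimal, dependency-free sketch of the hyperparameters this constructor reads. The attribute names come from the `getattr`/`args` accesses above; `SimpleNamespace` and the placeholder values are assumptions for illustration, not settings from any real model.

from types import SimpleNamespace

# Hypothetical argument object; only the fields read in code example #1 are set.
args = SimpleNamespace(
    early_exit="6,6,6",                       # exit layers for del_word, ins_mask, ins_word
    sampling_for_deletion=False,              # optional flag, read via getattr with default False
    no_share_maskpredictor=True,              # build a separate layer stack for mask prediction
    no_share_discriminator=True,              # build a separate layer stack for deletion
    share_discriminator_maskpredictor=False,  # tie the two extra stacks together
)

# Mirror of the parsing and validation done in __init__ above.
early_exit = [int(i) for i in args.early_exit.split(',')]
assert len(early_exit) == 3
if getattr(args, "share_discriminator_maskpredictor", False):
    assert getattr(args, "no_share_discriminator", False), \
        "must set separate discriminator"
print(early_exit)  # -> [6, 6, 6]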
Code example #2
File: bitransformer.py  Project: mbevila/qbert
    def __init__(self,
                 args,
                 dictionary,
                 embed_tokens,
                 no_encoder_attn=False,
                 left_pad=False,
                 final_norm=True):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn,
                         left_pad, final_norm)

        self.share_directions = getattr(args, 'decoder_share_directions',
                                        False)
        if self.share_directions:
            self.layers_bw = None
        else:
            self.layers_bw = nn.ModuleList([
                TransformerDecoderLayer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ])

        self.padding_idx = dictionary.pad()

        self.embedding_dim = self.embed_dim = args.decoder_embed_dim
        self.n_hidden_states = args.decoder_layers
        if args.decoder_use_biattention:
            self.n_hidden_states += 1

        self.use_biattention = args.decoder_use_biattention
        if self.use_biattention:
            self.biblock = BiTransformerDecoderLayer(args)
        else:
            self.biblock = None

        self.pad_fw = nn.Parameter(torch.randn(1, 1, args.decoder_embed_dim),
                                   requires_grad=True)
        self.pad_bw = nn.Parameter(torch.randn(1, 1, args.decoder_embed_dim),
                                   requires_grad=True)
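
As above, a small self-contained sketch of the option handling in this constructor (no fairseq or qbert imports); the attribute names are taken from the code, while the numeric values are placeholders.

from types import SimpleNamespace

# Hypothetical args with only the fields read in code example #2.
args = SimpleNamespace(
    decoder_layers=6,
    decoder_embed_dim=512,
    decoder_share_directions=False,  # share one layer stack for both directions?
    decoder_use_biattention=True,    # add the BiTransformerDecoderLayer block
)

# Mirror of the bookkeeping done in __init__ above.
share_directions = getattr(args, 'decoder_share_directions', False)
n_hidden_states = args.decoder_layers + (1 if args.decoder_use_biattention else 0)
print(share_directions, n_hidden_states)  # -> False 7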
Code example #3
    def __init__(self,
                 args,
                 dictionary,
                 embed_tokens,
                 classification_head=None):
        super().__init__(dictionary)
        self.onnx_trace = False
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        self.embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.self_target = args.self_target
        self.future_target = args.future_target
        self.past_target = args.past_target
        self.char_inputs = args.char_inputs

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(self.embed_dim)

        self.embed_positions = PositionalEmbedding(
            args.max_target_positions,
            self.embed_dim,
            self.padding_idx,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

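        # Two independent decoder stacks: forward_layers reads the sequence
        # left-to-right, backward_layers reads it right-to-left.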
        self.forward_layers = nn.ModuleList([
            TransformerDecoderLayer(
                args,
                no_encoder_attn=True,
                add_bias_kv=not args.no_bias_kv,
                add_zero_attn=args.no_bias_kv,
            ) for _ in range(args.decoder_layers)
        ])
        self.backward_layers = nn.ModuleList([
            TransformerDecoderLayer(
                args,
                no_encoder_attn=True,
                add_bias_kv=not args.no_bias_kv,
                add_zero_attn=args.no_bias_kv,
            ) for _ in range(args.decoder_layers)
        ])

        self.full_attn_layer = None
        self.full_linear_layer = None

        if self.self_target:
            if args.linear_final_layer:
                self.full_linear_layer = Linear(self.embed_dim * 2,
                                                self.embed_dim,
                                                args.linear_final_layer_bias)
            else:
                self.full_attn_layer = BidirectionalTransformerDecoderLayer(
                    args)

        self.load_softmax = not getattr(args, 'remove_head', False)
        self.embed_out = None
        self.adaptive_softmax = None
        self.classification_head = classification_head

        if self.load_softmax:
            if args.adaptive_softmax_cutoff is not None:
                self.adaptive_softmax = AdaptiveSoftmax(
                    len(dictionary),
                    args.decoder_embed_dim,
                    options.eval_str_list(args.adaptive_softmax_cutoff,
                                          type=int),
                    dropout=args.adaptive_softmax_dropout,
                )
            elif not self.share_input_output_embed:
                self.embed_out = nn.Parameter(
                    torch.Tensor(len(dictionary), self.embed_dim))
                nn.init.normal_(self.embed_out,
                                mean=0,
                                std=self.embed_dim**-0.5)
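
The output-projection branch at the end chooses between an adaptive softmax, tied input/output embeddings, and a separate `embed_out` matrix. The sketch below isolates that choice; it needs only PyTorch, and `vocab_size` plus the placeholder `args` values are illustrative assumptions.

import torch
import torch.nn as nn
from types import SimpleNamespace

# Hypothetical settings mirroring the attributes read in code example #3.
args = SimpleNamespace(
    share_decoder_input_output_embed=False,  # reuse input embeddings as output weights?
    adaptive_softmax_cutoff=None,            # a non-None cutoff enables the AdaptiveSoftmax branch
    remove_head=False,                       # when True, no output layer is built at all
)
vocab_size, embed_dim = 1000, 512

load_softmax = not getattr(args, 'remove_head', False)
embed_out = None
if load_softmax and args.adaptive_softmax_cutoff is None \
        and not args.share_decoder_input_output_embed:
    # Separate output projection, initialized like the token embeddings.
    embed_out = nn.Parameter(torch.Tensor(vocab_size, embed_dim))
    nn.init.normal_(embed_out, mean=0, std=embed_dim ** -0.5)

print(None if embed_out is None else tuple(embed_out.shape))  # -> (1000, 512)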