def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False,
                 add_suphead=False):
        """Create a Transformer decoder layer with an optional "suphead" attention.

        Args:
            args: configuration namespace. Reads ``decoder_embed_dim``,
                ``decoder_attention_heads``, ``attention_dropout``, ``dropout``,
                ``decoder_normalize_before`` and ``decoder_ffn_embed_dim``;
                optionally ``cross_self_attention``, ``encoder_embed_dim``,
                ``activation_fn`` (default 'relu'), ``activation_dropout``,
                ``relu_dropout`` and ``char_inputs``.
            no_encoder_attn: when True, ``encoder_attn`` and
                ``encoder_attn_layer_norm`` are set to None.
            add_bias_kv: forwarded to the self-attention module.
            add_zero_attn: forwarded to the self-attention module.
            add_suphead: when True, ``self.suphead`` is an extra
                encoder-decoder ``MultiheadAttention`` with the decoder's head
                count; otherwise ``self.suphead`` is None.
        """
        super().__init__()
        dim = args.decoder_embed_dim
        self.embed_dim = dim
        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
        self.self_attn = MultiheadAttention(
            embed_dim=dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )

        # Key/value width for attention over the encoder output (None if unset).
        enc_dim = getattr(args, 'encoder_embed_dim', None)

        if not add_suphead:
            self.suphead = None
        else:
            self.suphead = MultiheadAttention(
                dim,
                args.decoder_attention_heads,
                kdim=enc_dim,
                vdim=enc_dim,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        # A zero/unset activation_dropout falls back to the legacy relu_dropout option.
        act_dropout = getattr(args, 'activation_dropout', 0)
        if act_dropout == 0:
            act_dropout = getattr(args, 'relu_dropout', 0)
        self.activation_dropout = act_dropout
        self.normalize_before = args.decoder_normalize_before

        # ``char_inputs`` doubles as the "export" switch handed to LayerNorm
        # (plain LayerNorm instead of FusedLayerNorm when exporting).
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                dim,
                args.decoder_attention_heads,
                kdim=enc_dim,
                vdim=enc_dim,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(dim, export=export)

        ffn_dim = args.decoder_ffn_embed_dim
        self.fc1 = Linear(dim, ffn_dim)
        self.fc2 = Linear(ffn_dim, dim)

        self.final_layer_norm = LayerNorm(dim, export=export)
        self.need_attn = True
        self.onnx_trace = False
    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        """Build a slimmable Transformer decoder layer.

        Width-dependent sub-modules (layer norms, width lists) use 13 candidate
        widths: ``int(d * k / 16)`` for ``k = 4 .. 15``, followed by the full
        ``d`` (``d`` is ``embed_dim`` or the FFN width). ``self.dropout`` holds
        one dropout rate (0.3) per candidate width.

        Args:
            args: configuration namespace. Reads ``decoder_embed_dim``,
                ``decoder_attention_heads``, ``attention_dropout``,
                ``decoder_normalize_before`` and ``encoder_ffn_embed_dim``;
                optionally ``cross_self_attention``, ``activation_fn``
                (default "relu"), ``activation_dropout``, ``relu_dropout`` and
                ``encoder_embed_dim``.
            no_encoder_attn: when True, ``encoder_attn`` and
                ``encoder_attn_layer_norm`` are set to None.
            add_bias_kv: forwarded to the self-attention module.
            add_zero_attn: forwarded to the self-attention module.
        """
        super().__init__()

        def slim_widths(full):
            # 13 candidate widths: 4/16 .. 15/16 of `full` (truncated), then `full`.
            # A fresh list is returned per call so no two modules share one object.
            return [int(full * k / 16) for k in range(4, 16)] + [full]

        self.embed_dim = args.decoder_embed_dim
        embed_dim = self.embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention", False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        # One dropout rate per candidate width.
        self.dropout = [0.3] * len(slim_widths(embed_dim))
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = SlimmableLayernorm(slim_widths(embed_dim))

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            # Built once, with the same 13 widths as the other layer norms.
            self.encoder_attn_layer_norm = SlimmableLayernorm(slim_widths(embed_dim))

        # NOTE(review): this is a decoder layer, yet the FFN width comes from
        # ``encoder_ffn_embed_dim`` rather than ``decoder_ffn_embed_dim``. Kept
        # unchanged so existing checkpoints keep their shapes — confirm intent.
        ffn_dim = args.encoder_ffn_embed_dim
        self.fc1 = SLinear(embed_dim, ffn_dim)
        self.fc2 = SLinear(ffn_dim, embed_dim)

        self.final_layer_norm = SlimmableLayernorm(slim_widths(embed_dim))

        # Candidate widths of the model dimension and of the FFN hidden layer.
        self.linear_list = slim_widths(embed_dim)
        self.ffn_list = slim_widths(ffn_dim)

        self.need_attn = True

        self.onnx_trace = False
Example #3
0
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        """Create a Transformer decoder layer configured from ``args``.

        Args:
            args: configuration namespace. Reads ``decoder_embed_dim``,
                ``dropout``, ``decoder_normalize_before`` and
                ``decoder_ffn_embed_dim``; optionally ``quant_noise_pq``
                (default 0), ``quant_noise_pq_block_size`` (default 8),
                ``cross_self_attention`` (default False), ``activation_fn``
                (default "relu"), ``activation_dropout``, ``relu_dropout`` and
                ``char_inputs``. Also handed whole to ``build_self_attention``
                and ``build_encoder_attention``.
            no_encoder_attn: when True, ``encoder_attn`` and
                ``encoder_attn_layer_norm`` are set to None.
            add_bias_kv: forwarded to ``build_self_attention``.
            add_zero_attn: forwarded to ``build_self_attention``.
        """
        super().__init__()
        dim = args.decoder_embed_dim
        self.embed_dim = dim
        owner = self.__class__.__name__
        self.dropout_module = FairseqDropout(args.dropout, module_name=owner)
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        self.self_attn = self.build_self_attention(
            dim, args, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn
        )

        act_name = getattr(args, "activation_fn", None)
        if act_name is None:
            act_name = "relu"
        else:
            act_name = str(act_name)
        self.activation_fn = utils.get_activation_fn(activation=act_name)

        # ``or 0`` maps an explicit None to 0; a zero rate then falls back to
        # the legacy ``relu_dropout`` option.
        act_p = getattr(args, "activation_dropout", 0) or 0
        if act_p == 0:
            act_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(float(act_p), module_name=owner)
        self.normalize_before = args.decoder_normalize_before

        # ``char_inputs`` doubles as the "export" switch passed to LayerNorm.
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(dim, args)
            self.encoder_attn_layer_norm = LayerNorm(dim, export=export)

        ffn_dim = args.decoder_ffn_embed_dim
        qn, qn_block = self.quant_noise, self.quant_noise_block_size
        self.fc1 = self.build_fc1(dim, ffn_dim, qn, qn_block)
        self.fc2 = self.build_fc2(ffn_dim, dim, qn, qn_block)

        self.final_layer_norm = LayerNorm(dim, export=export)
        self.need_attn = True
        self.onnx_trace = False
    def __init__(self,
                 cfg,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        """Create a Transformer decoder layer configured from a structured ``cfg``.

        Args:
            cfg: configuration object. Reads ``decoder.embed_dim``,
                ``decoder.ffn_embed_dim``, ``decoder.normalize_before``,
                ``dropout``, ``quant_noise.pq``, ``quant_noise.pq_block_size``,
                ``cross_self_attention``, ``activation_fn``,
                ``activation_dropout``, ``relu_dropout`` and ``export``, plus
                the optional flags ``scale_attn``, ``scale_heads``,
                ``scale_fc`` and ``scale_resids`` (each default False).
            no_encoder_attn: when True, ``encoder_attn`` and
                ``encoder_attn_layer_norm`` are set to None.
            add_bias_kv: forwarded to ``build_self_attention``.
            add_zero_attn: forwarded to ``build_self_attention``.
        """
        super().__init__()
        dim = cfg.decoder.embed_dim
        self.embed_dim = dim
        owner = self.__class__.__name__
        self.dropout_module = FairseqDropout(cfg.dropout, module_name=owner)
        self.quant_noise = cfg.quant_noise.pq
        self.quant_noise_block_size = cfg.quant_noise.pq_block_size

        self.cross_self_attention = cfg.cross_self_attention

        self.self_attn = self.build_self_attention(
            dim, cfg, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn
        )

        # Extra LayerNorm over the embedding dim, enabled by ``scale_attn``.
        if utils.safe_getattr(cfg, 'scale_attn', False):
            self.attn_ln = LayerNorm(dim)
        else:
            self.attn_ln = None
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim
        # One learnable scale per attention head, enabled by ``scale_heads``.
        if utils.safe_getattr(cfg, 'scale_heads', False):
            self.c_attn = nn.Parameter(torch.ones((self.nh,)), requires_grad=True)
        else:
            self.c_attn = None

        self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn)
        # A zero activation_dropout falls back to the legacy relu_dropout option.
        act_p = cfg.activation_dropout
        if act_p == 0:
            act_p = cfg.relu_dropout or 0
        self.activation_dropout_module = FairseqDropout(float(act_p), module_name=owner)
        self.normalize_before = cfg.decoder.normalize_before

        self.self_attn_layer_norm = LayerNorm(dim, export=cfg.export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(dim, cfg)
            self.encoder_attn_layer_norm = LayerNorm(dim, export=cfg.export)

        ffn_dim = cfg.decoder.ffn_embed_dim
        # Extra LayerNorm over the FFN hidden width, enabled by ``scale_fc``.
        if utils.safe_getattr(cfg, 'scale_fc', False):
            self.ffn_layernorm = LayerNorm(ffn_dim)
        else:
            self.ffn_layernorm = None
        # Learnable per-channel residual weights, enabled by ``scale_resids``.
        if utils.safe_getattr(cfg, 'scale_resids', False):
            self.w_resid = nn.Parameter(torch.ones(dim), requires_grad=True)
        else:
            self.w_resid = None

        qn, qn_block = self.quant_noise, self.quant_noise_block_size
        self.fc1 = self.build_fc1(dim, ffn_dim, qn, qn_block)
        self.fc2 = self.build_fc2(ffn_dim, dim, qn, qn_block)

        self.final_layer_norm = LayerNorm(dim, export=cfg.export)
        self.need_attn = True
        self.onnx_trace = False
Example #5
0
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False,
                 LayerNum=None):
        """Create a decoder layer supporting several initialization schemes.

        ``args.init_type`` selects the scheme:

        * contains ``'adaptive'``: post-LN layer with per-sublayer
          ``self_ratio_change`` / ``encoder_ratio_change`` / ``fc_ratio_change``
          parameters. ``'adaptive-profiling'`` starts them at 1; every other
          adaptive variant fills them from ``profile.ratio.init``.
        * exactly ``'rezero'``: no layer norms, and a scalar ``rezero_weight``
          initialized to 0.
        * any other value containing ``'rezero'``: layer norms plus ``rezero_weight``.
        * ``'looklinear'``: the second half of the FFN hidden units is
          initialized as the negation of the first half; layer norms are
          kept and ``rezero_weight`` is None.
        * ``'default'``: layer norms, ``rezero_weight`` is None.

        Args:
            args: configuration namespace (``init_type``, ``decoder_*`` sizes,
                ``attention_dropout``, ``dropout``, ``fp16``, ...). Missing
                ``mixed_precision`` / ``plot_variance`` / ``plot_gradient``
                attributes are set to False on it (``args`` is mutated).
            no_encoder_attn: must be False; the encoder-decoder attention is
                always built (enforced by an assert).
            add_bias_kv: forwarded to the self-attention module.
            add_zero_attn: forwarded to the self-attention module.
            LayerNum: index of this layer. In the non-profiling adaptive mode
                it must be an int: line ``3 * LayerNum + k`` of the ratio file
                is expected for the k-th sublayer (k = 1, 2, 3).

        Raises:
            AssertionError: on an unsupported ``init_type``, when
                ``no_encoder_attn`` is True, when an adaptive layer has
                ``decoder_normalize_before`` set, or when the ratio file's line
                counter does not match ``LayerNum``.
        """
        super().__init__()

        # Module-level handle on 'profile.ratio.init', shared by every layer
        # constructed in this process and not closed here.
        global tmp_file

        self.args = args
        # Default the optional diagnostic flags (mutates the caller's args).
        if not hasattr(self.args, 'mixed_precision'):
            self.args.mixed_precision = False
        if not hasattr(self.args, 'plot_variance'):
            self.args.plot_variance = False
        if not hasattr(self.args, 'plot_gradient'):
            self.args.plot_gradient = False

        self.normalize_before = args.decoder_normalize_before
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, 'cross_self_attention',
                                            False)

        self.layer_num = LayerNum
        if 'adaptive' in args.init_type:
            # Adaptive init is only supported for post-LN layers.
            assert not self.normalize_before

            self.self_attn = MultiheadAttention(
                embed_dim=self.embed_dim,
                num_heads=args.decoder_attention_heads,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=not self.cross_self_attention)

            assert not no_encoder_attn
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True)

            self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
            self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

            if 'adaptive-profiling' == args.init_type:
                # Profiling run: open the ratio file for writing (nothing is
                # written here) and start every ratio at 1.
                if not tmp_file:
                    tmp_file = open('profile.ratio.init', 'w')
                self.self_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.encoder_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
            else:
                # Consume three "<layer_iter> <ratio>" lines, one per sublayer
                # (self-attn, encoder-attn, FFN), and check each line counter
                # matches this layer's position.
                if not tmp_file:
                    tmp_file = open('profile.ratio.init', 'r')

                layer_iter, next_value = [
                    float(tup) for tup in tmp_file.readline().split()
                ]
                print('layer_num: {}, layer_iter: {}'.format(
                    self.layer_num, layer_iter))
                assert layer_iter == 3 * self.layer_num + 1
                print('decoder self ratio: {}'.format(next_value))
                self.self_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.self_ratio_change.data.fill_(next_value)

                layer_iter, next_value = [
                    float(tup) for tup in tmp_file.readline().split()
                ]
                print('layer_num: {}, layer_iter: {}'.format(
                    self.layer_num, layer_iter))
                assert layer_iter == 3 * self.layer_num + 2
                print('decoder en ratio: {}'.format(next_value))
                self.encoder_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.encoder_ratio_change.data.fill_(next_value)

                layer_iter, next_value = [
                    float(tup) for tup in tmp_file.readline().split()
                ]
                print('layer_num: {}, layer_iter: {}'.format(
                    self.layer_num, layer_iter))
                assert layer_iter == 3 * self.layer_num + 3
                print('decoder ffn ratio: {}'.format(next_value))
                self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
                self.fc_ratio_change.data.fill_(next_value)

            export = getattr(args, 'char_inputs', False)
            self.self_attn_layer_norm = LayerNorm(self.embed_dim,
                                                  export=export)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)
            self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        else:
            self.self_attn = MultiheadAttention(
                embed_dim=self.embed_dim,
                num_heads=args.decoder_attention_heads,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=not self.cross_self_attention)

            assert not no_encoder_attn
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True)

            self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
            self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
            if args.init_type == 'looklinear':
                # Look-linear init: hidden units in the second half are the
                # negation of those in the first half (rows of fc1, columns of fc2).
                self.fc1.weight.data[int(args.decoder_ffn_embed_dim /
                                         2):, :] = -self.fc1.weight.data[
                                             0:int(args.decoder_ffn_embed_dim /
                                                   2), :]
                self.fc2.weight.data[:,
                                     int(args.decoder_ffn_embed_dim /
                                         2):] = -self.fc2.weight.data[:, 0:int(
                                             args.decoder_ffn_embed_dim / 2)]

            export = getattr(args, 'char_inputs', False)

            if args.init_type != 'rezero':
                self.self_attn_layer_norm = LayerNorm(self.embed_dim,
                                                      export=export)
                # NOTE(review): no_encoder_attn was asserted False above, so
                # this branch is only reachable with assertions disabled (-O).
                if no_encoder_attn:
                    self.encoder_attn = None
                    self.encoder_attn_layer_norm = None
                else:
                    self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                             export=export)
                self.final_layer_norm = LayerNorm(self.embed_dim,
                                                  export=export)
            else:
                self.self_attn_layer_norm = None
                self.encoder_attn_layer_norm = None
                self.final_layer_norm = None

            if 'rezero' in args.init_type:
                self.rezero_weight = nn.Parameter(torch.Tensor([0]))
            else:
                # 'looklinear' only changes the FFN init above, so it must be
                # accepted here alongside 'default'.
                assert args.init_type in ('default', 'looklinear'), \
                    'unsupported init_type: {}'.format(args.init_type)
                self.rezero_weight = None

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # Backwards compatibility with the legacy relu_dropout option.
            self.activation_dropout = getattr(args, 'relu_dropout', 0)

        self.need_attn = True

        self.onnx_trace = False

        # Record the tensor dtype implied by args.fp16.
        if args.fp16:
            self.in_type = torch.half
        else:
            self.in_type = torch.float
Example #6
0
    def __init__(self, args, LayerNum=None):
        """Create an encoder layer supporting several initialization schemes.

        ``args.init_type`` selects the scheme:

        * contains ``'adaptive'``: post-LN layer with per-sublayer
          ``attention_ratio_change`` / ``fc_ratio_change`` parameters.
          ``'adaptive-profiling'`` starts them at 1; every other adaptive
          variant fills them from ``profile.ratio.init``.
        * exactly ``'rezero'``: no layer norms, and a scalar ``rezero_weight``
          initialized to 0.
        * any other value containing ``'rezero'``: layer norms plus ``rezero_weight``.
        * ``'looklinear'``: the second half of the FFN hidden units is
          initialized as the negation of the first half; layer norms are
          kept and ``rezero_weight`` is None.
        * ``'default'``: layer norms, ``rezero_weight`` is None.

        Args:
            args: configuration namespace (``init_type``, ``encoder_*`` sizes,
                ``attention_dropout``, ``dropout``, ``fp16``, ...). Missing
                ``mixed_precision`` / ``plot_variance`` / ``plot_gradient`` /
                ``plot_stability`` attributes are set to False on it (``args``
                is mutated).
            LayerNum: index of this layer. In the non-profiling adaptive mode
                it must be an int: line ``2 * LayerNum + k`` of the ratio file
                is expected for the k-th sublayer (k = 1, 2). With
                ``plot_stability`` it also decides whether ``x_final`` is
                created (last layer only).

        Raises:
            AssertionError: on an unsupported ``init_type``, when an adaptive
                layer has ``encoder_normalize_before`` set, or when the ratio
                file's line counter does not match ``LayerNum``.
        """
        super().__init__()
        # Module-level handle on 'profile.ratio.init', shared by every layer
        # constructed in this process and not closed here.
        global tmp_file

        self.args = args
        # Default the optional diagnostic flags (mutates the caller's args).
        if not hasattr(self.args, 'mixed_precision'):
            self.args.mixed_precision = False
        if not hasattr(self.args, 'plot_variance'):
            self.args.plot_variance = False
        if not hasattr(self.args, 'plot_gradient'):
            self.args.plot_gradient = False
        if not hasattr(self.args, 'plot_stability'):
            self.args.plot_stability = False

        self.normalize_before = args.encoder_normalize_before
        self.embed_dim = args.encoder_embed_dim

        self.layer_num = LayerNum
        # if LayerNum is not None and not self.normalize_before:
        if 'adaptive' in args.init_type:
            # Adaptive init is only supported for post-LN layers.
            assert not self.normalize_before

            self.self_attn = MultiheadAttention(self.embed_dim,
                                                args.encoder_attention_heads,
                                                dropout=args.attention_dropout,
                                                self_attention=True)

            self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
            self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)

            if 'adaptive-profiling' == args.init_type:
                # Profiling run: open the ratio file for writing (nothing is
                # written here) and start every ratio at 1.
                if not tmp_file:
                    tmp_file = open('profile.ratio.init', 'w')
                self.attention_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
            else:
                # Consume two "<layer_iter> <ratio>" lines, one per sublayer
                # (self-attn, FFN), and check each line counter matches this
                # layer's position.
                if not tmp_file:
                    tmp_file = open('profile.ratio.init', 'r')

                layer_iter, next_value = [
                    float(tup) for tup in tmp_file.readline().split()
                ]
                print('layer_num: {}, layer_iter: {}'.format(
                    self.layer_num, layer_iter))
                assert layer_iter == 2 * self.layer_num + 1
                print('encoder attn ratio: {}'.format(next_value))
                self.attention_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.attention_ratio_change.data.fill_(next_value)

                layer_iter, next_value = [
                    float(tup) for tup in tmp_file.readline().split()
                ]
                print('layer_num: {}, layer_iter: {}'.format(
                    self.layer_num, layer_iter))
                assert layer_iter == 2 * self.layer_num + 2
                print('encoder ffn ratio: {}'.format(next_value))
                self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
                self.fc_ratio_change.data.fill_(next_value)

            self.self_attn_layer_norm = LayerNorm(self.embed_dim)
            self.final_layer_norm = LayerNorm(self.embed_dim)

        else:

            self.self_attn = MultiheadAttention(self.embed_dim,
                                                args.encoder_attention_heads,
                                                dropout=args.attention_dropout,
                                                self_attention=True)

            self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
            self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
            if args.init_type == 'looklinear':
                # Look-linear init: hidden units in the second half are the
                # negation of those in the first half (rows of fc1, columns of fc2).
                self.fc1.weight.data[int(args.encoder_ffn_embed_dim /
                                         2):, :] = -self.fc1.weight.data[
                                             0:int(args.encoder_ffn_embed_dim /
                                                   2), :]
                self.fc2.weight.data[:,
                                     int(args.encoder_ffn_embed_dim /
                                         2):] = -self.fc2.weight.data[:, 0:int(
                                             args.encoder_ffn_embed_dim / 2)]

            if args.init_type != 'rezero':
                self.self_attn_layer_norm = LayerNorm(self.embed_dim)
                self.final_layer_norm = LayerNorm(self.embed_dim)
            else:
                self.self_attn_layer_norm = None
                self.final_layer_norm = None

            if 'rezero' in args.init_type:
                self.rezero_weight = nn.Parameter(torch.Tensor([0]))
            else:
                # 'looklinear' only changes the FFN init above, so it must be
                # accepted here alongside 'default'.
                assert args.init_type in ('default', 'looklinear'), \
                    'unsupported init_type: {}'.format(args.init_type)
                self.rezero_weight = None

        if self.args.plot_stability:
            # Placeholders filled later for stability plotting; only the last
            # encoder layer gets x_final.
            self.x0_hat = None
            self.x1_hat = None
            if self.layer_num == self.args.encoder_layers - 1:
                self.x_final = None

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # Backwards compatibility with the legacy relu_dropout option.
            self.activation_dropout = getattr(args, 'relu_dropout', 0)

        # Record the tensor dtype implied by args.fp16.
        if args.fp16:
            self.in_type = torch.half
        else:
            self.in_type = torch.float
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        """Create a multi-branch Transformer decoder layer.

        The layer builds:

        * ``args.decoder_branches`` parallel self-attention modules and, unless
          ``no_encoder_attn``, the same number of encoder-decoder attention
          modules;
        * ``args.decoder_pffn_branches`` parallel ``fc1`` / ``fc2`` Linear layers.

        When ``args.enable_head_dropout`` is truthy, every attention branch
        receives ``head_dropout=args.branch_dropout``; otherwise it receives
        ``head_dropout=None``.

        Args:
            args: configuration namespace (``decoder_*`` sizes and branch
                counts, ``branch_dropout``, ``pffn_branch_dropout``,
                ``enable_head_dropout``, ``join_pffn``, ``attention_dropout``,
                ``dropout``, ...).
            no_encoder_attn: when True, ``encoder_attn_branches`` and
                ``encoder_attn_layer_norm`` are set to None.
            add_bias_kv: forwarded to each self-attention branch.
            add_zero_attn: forwarded to each self-attention branch.
        """
        super().__init__()
        dim = args.decoder_embed_dim
        self.embed_dim = dim
        self.num_branches = args.decoder_branches
        self.num_pffn_branches = args.decoder_pffn_branches
        self.branch_dropout = args.branch_dropout
        self.pffn_branch_dropout = args.pffn_branch_dropout
        self.enable_head_dropout = args.enable_head_dropout
        self.join_pffn = args.join_pffn

        head_dropout = None
        if self.enable_head_dropout:
            head_dropout = self.branch_dropout

        self_branches = []
        for _ in range(self.num_branches):
            self_branches.append(
                MultiheadAttention(
                    embed_dim=dim,
                    num_heads=args.decoder_attention_heads,
                    dropout=args.attention_dropout,
                    add_bias_kv=add_bias_kv,
                    add_zero_attn=add_zero_attn,
                    self_attention=True,
                    head_dropout=head_dropout,
                ))
        self.self_attn_branches = nn.ModuleList(self_branches)

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        # A zero/unset activation_dropout falls back to the legacy relu_dropout option.
        act_dropout = getattr(args, 'activation_dropout', 0)
        if act_dropout == 0:
            act_dropout = getattr(args, 'relu_dropout', 0)
        self.activation_dropout = act_dropout
        self.normalize_before = args.decoder_normalize_before

        # ``char_inputs`` doubles as the "export" switch passed to LayerNorm.
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(dim, export=export)

        if no_encoder_attn:
            self.encoder_attn_branches = None
            self.encoder_attn_layer_norm = None
        else:
            enc_dim = getattr(args, 'encoder_embed_dim', None)
            cross_branches = []
            for _ in range(self.num_branches):
                cross_branches.append(
                    MultiheadAttention(
                        dim,
                        args.decoder_attention_heads,
                        kdim=enc_dim,
                        vdim=enc_dim,
                        dropout=args.attention_dropout,
                        encoder_decoder_attention=True,
                        head_dropout=head_dropout,
                    ))
            self.encoder_attn_branches = nn.ModuleList(cross_branches)
            self.encoder_attn_layer_norm = LayerNorm(dim, export=export)

        # All fc1 branches are created before any fc2 branch.
        ffn_dim = args.decoder_ffn_embed_dim
        self.fc1_branches = nn.ModuleList(
            [Linear(dim, ffn_dim) for _ in range(self.num_pffn_branches)])
        self.fc2_branches = nn.ModuleList(
            [Linear(ffn_dim, dim) for _ in range(self.num_pffn_branches)])

        self.final_layer_norm = LayerNorm(dim, export=export)
        self.need_attn = True
        self.onnx_trace = False
Example #8
0
 def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout):
     """Create the layers of a classification head.

     Args:
         input_dim: input feature size of ``self.dense``.
         inner_dim: output size of ``self.dense`` and input size of ``self.out_proj``.
         num_classes: output size of ``self.out_proj``.
         activation_fn: activation name, resolved with ``utils.get_activation_fn``.
         pooler_dropout: drop probability of ``self.dropout``.
     """
     super().__init__()
     hidden_size = inner_dim
     self.dense = nn.Linear(in_features=input_dim, out_features=hidden_size)
     self.activation_fn = utils.get_activation_fn(activation_fn)
     self.dropout = nn.Dropout(pooler_dropout)
     self.out_proj = nn.Linear(in_features=hidden_size, out_features=num_classes)