    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        self.self_attn_layer_norm = FusedLayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = FusedLayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = FusedLayerNorm(self.embed_dim)
        self.need_attn = True
Example #2
    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = DecoderAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = nn.Dropout(p=args.dropout)
        self.relu_dropout = nn.Dropout(p=args.relu_dropout)
        self.normalize_before = args.decoder_normalize_before
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        self.self_attn_layer_norm = FusedLayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = EncoderAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
                static_kv=True)
            self.encoder_attn_layer_norm = FusedLayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = FusedLayerNorm(self.embed_dim)
        self.need_attn = True
        self.threshold = nn.Threshold(0, 0)
Example #3
    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.0,
                 bias=False,
                 include_norm_add=False,
                 impl="fast"):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.bias = bias
        self.include_norm_add = include_norm_add
        self.impl = impl
        self.scaling = self.head_dim**-0.5

        self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
        self.in_proj_weight_kv = Parameter(
            torch.Tensor(2 * embed_dim, embed_dim))
        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        if self.bias:
            assert impl != "fast", "ERROR! The Fast implementation does not support biases!"
            self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
            self.in_proj_bias_kv = Parameter(torch.Tensor(2 * embed_dim))
            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
        else:
            self.register_parameter("in_proj_bias_q", None)
            self.register_parameter("in_proj_bias_kv", None)
            self.in_proj_bias_q = None
            self.in_proj_bias_kv = None
            self.out_proj_bias = None
        if self.include_norm_add:
            if impl == "fast":
                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm = None
            else:
                self.register_parameter("lyr_norm_gamma_weights", None)
                self.register_parameter("lyr_norm_beta_weights", None)
                self.lyr_nrm_gamma_weights = None
                self.lyr_nrm_beta_weights = None
                self.lyr_nrm = FusedLayerNorm(embed_dim)
        self.reset_parameters()

        if self.include_norm_add:
            if impl == "fast":
                self.attn_func = fast_encdec_attn_norm_add_func
            elif impl == "default":
                self.attn_func = encdec_attn_func
            else:
                assert False, "Unsupported impl: {} !".format(impl)
        else:
            if impl == "fast":
                self.attn_func = fast_encdec_attn_func
            elif impl == "default":
                self.attn_func = encdec_attn_func
            else:
                assert False, "Unsupported impl: {} !".format(impl)
Example #4
    def __init__(self, config, img_dim):
        super().__init__()
        self.img_linear = nn.Linear(img_dim, config.hidden_size)
        self.img_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12)
        self.pos_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12)
        self.pos_linear = nn.Linear(7, config.hidden_size)

        # tf naming convention for layer norm
        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
Example #5
    def __init__(self, config, img_dim, max_img_seq_len):
        super().__init__()
        self.img_linear = nn.Linear(img_dim, config.hidden_size)
        self.img_LayerNorm = FusedLayerNorm(img_dim, eps=1e-5)
        self.position_embeddings = nn.Embedding(max_img_seq_len,
                                                config.hidden_size)
        self.mask_embedding = nn.Embedding(2, img_dim, padding_idx=0)

        # tf naming convention for layer norm
        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-5)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
Example #6
    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.multihead_attn_impl = args.multihead_attn_impl
        if args.multihead_attn_impl == 'fast_with_lyrnrm_and_dropoutadd':
            self.self_attn = SelfMultiheadAttn(
                self.embed_dim,
                args.encoder_attention_heads,
                dropout=args.attention_dropout,
                bias=False,
                include_norm_add=True,
                impl='fast',
            )
        elif args.multihead_attn_impl == 'fast':
            self.self_attn = SelfMultiheadAttn(
                self.embed_dim,
                args.encoder_attention_heads,
                dropout=args.attention_dropout,
                bias=False,
                include_norm_add=False,
                impl='fast',
            )
        else:
            self.self_attn = SelfMultiheadAttn(
                self.embed_dim,
                args.encoder_attention_heads,
                dropout=args.attention_dropout,
                bias=False,
                include_norm_add=False,
                impl='default',
            )

        # in_proj_weight has shape [3 * hidden, hidden] but it should be
        # initialized like a [hidden, hidden] matrix.
        # sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
        # therefore xavier_uniform gain should be set to sqrt(2).
        torch.nn.init.xavier_uniform_(self.self_attn.in_proj_weight,
                                      gain=math.sqrt(2))

        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        if args.multihead_attn_impl == 'fast_with_lyrnrm_and_dropoutadd':
            self.layer_norms = nn.ModuleList(
                [FusedLayerNorm(self.embed_dim) for i in range(1)])
        else:
            self.layer_norms = nn.ModuleList(
                [FusedLayerNorm(self.embed_dim) for i in range(2)])
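The gain=sqrt(2) correction described in the comment above can be checked directly: xavier_uniform_ on a [3 * hidden, hidden] weight uses the bound sqrt(6 / (3 * hidden + hidden)), and multiplying it by sqrt(2) recovers the bound sqrt(6 / (hidden + hidden)) of a square [hidden, hidden] matrix. A minimal sketch, assuming only PyTorch:

import math
import torch

hidden = 512
w = torch.empty(3 * hidden, hidden)
torch.nn.init.xavier_uniform_(w, gain=math.sqrt(2))

# Bound actually used on the [3 * hidden, hidden] tensor with gain sqrt(2)
bound_3h = math.sqrt(2) * math.sqrt(6.0 / (3 * hidden + hidden))
# Bound xavier_uniform_ would use on a square [hidden, hidden] tensor (default gain)
bound_sq = math.sqrt(6.0 / (hidden + hidden))
assert abs(bound_3h - bound_sq) < 1e-12
assert w.abs().max().item() <= bound_3h + 1e-6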
Example #7
    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        transformer_print(key=mlperf_log.MODEL_HP_RELU_DROPOUT,
                          value=self.relu_dropout)
        self.normalize_before = args.decoder_normalize_before
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        self.self_attn_layer_norm = FusedLayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = FusedLayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = FusedLayerNorm(self.embed_dim)
        self.need_attn = True

        transformer_print(key=mlperf_log.MODEL_HP_FFN_FILTER_DENSE,
                          value={
                              'filter_size': args.decoder_ffn_embed_dim,
                              'activation': 'relu',
                              'use_bias': True
                          })
        transformer_print(key=mlperf_log.MODEL_HP_FFN_OUTPUT_DENSE,
                          value={
                              'hidden_size': self.embed_dim,
                              'use_bias': True
                          })
        transformer_print(key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
                          value=self.dropout)
        transformer_print(key=mlperf_log.MODEL_HP_NORM, value=self.embed_dim)
Example #8
    def __init__(self, config, img_dim):
        super().__init__()
        '''
        Build the final image embeddings, used as input to the ImgEncoder or the
        unified Encoder: embeddings = imageEmd + PosEmd + typeEmb
        '''
        self.config = config
        self.img_linear = nn.Linear(img_dim, config.hidden_size)
        self.img_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12)
        self.pos_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12)
        self.pos_linear = nn.Linear(7, config.hidden_size)
        self.mask_embedding = nn.Embedding(2, img_dim, padding_idx=0)

        # tf naming convention for layer norm
        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
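Only the constructor is shown above; the docstring states that the output is imageEmd + PosEmd + typeEmb. The forward below is an illustrative sketch of that sum, not the original code, and its argument names are assumptions:

    def forward(self, img_feat, img_pos_feat, type_embeddings):
        # imageEmd: project region features to hidden_size and normalize
        transformed_im = self.img_layer_norm(self.img_linear(img_feat))
        # PosEmd: project the 7-dimensional box geometry features and normalize
        transformed_pos = self.pos_layer_norm(self.pos_linear(img_pos_feat))
        embeddings = transformed_im + transformed_pos + type_embeddings
        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)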
Example #9
    def replace_layer_norm(m, name):

        replacable = True
        try:
            # from apex.normalization.fused_layer_norm import FusedLayerNorm
            import importlib
            from apex.normalization.fused_layer_norm import FusedLayerNorm
            fused_layer_norm_cuda = importlib.import_module(
                "fused_layer_norm_cuda")

        except ModuleNotFoundError:
            replacable = False

        if replacable:
            for attr_str in dir(m):
                target_attr = getattr(m, attr_str)
                if type(target_attr) == torch.nn.LayerNorm:
                    setattr(
                        m, attr_str,
                        FusedLayerNorm(
                            target_attr.normalized_shape,
                            eps=target_attr.eps,
                            elementwise_affine=target_attr.elementwise_affine))
            for n, ch in m.named_children():
                replace_layer_norm(ch, n)
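A possible way to invoke the recursive replacement above, assuming replace_layer_norm is in scope and apex is installed; the model here is hypothetical:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(32, 32),
    nn.LayerNorm(32),
    nn.Linear(32, 8),
)
# Walk the module tree and swap every nn.LayerNorm for apex's FusedLayerNorm in place.
replace_layer_norm(model, "model")
print(model)  # the nn.LayerNorm entry should now print as FusedLayerNorm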
Example #10
    def __init__(self, args, embed_tokens, no_encoder_attn=False, left_pad=False):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        embed_dim = embed_tokens.embedding_dim
        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, embed_dim, padding_idx,
            left_pad=left_pad,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(torch.Tensor(args.tgt_vocab_size, embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
        else:
            self.embed_out = self.embed_tokens.weight
        self.normalize = args.decoder_normalize_before
        if self.normalize:
            self.layer_norm = FusedLayerNorm(embed_dim) if args.fuse_layer_norm else nn.LayerNorm(embed_dim)
Example #11
    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions,
            embed_dim,
            self.padding_idx,
            left_pad=left_pad,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args) for i in range(args.encoder_layers)
        ])

        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = FusedLayerNorm(embed_dim)
Example #12
 def __init__(self, config):
     super().__init__()
     self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                             config.hidden_size)
     # self.LayerNorm is not snake-cased to stick with TensorFlow model
     # variable name and be able to load any TensorFlow checkpoint file
     self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-5)
     self.dropout = nn.Dropout(config.hidden_dropout_prob)
Example #13
    def __init__(self, d_in, d_out):
        self.d_in = d_in
        self.d_out = d_out
        super().__init__()

        self.norm = torch.nn.LayerNorm(d_in)
        self.norm2 = FusedLayerNorm(d_out)
        # self.norm2 = torch.nn.LayerNorm(d_out)
        self.linear = torch.nn.Linear(d_in, d_out)
        self.linear2 = torch.nn.Linear(d_out, d_out)
Example #14
    def __init__(self, config):
        super().__init__()
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model
        # variable name and be able to load any TensorFlow checkpoint file
        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.gcn = GCN(nfeat=768, nhid=768, noutput=768, dropout=0.5)
Example #15
 def __init__(self, args):
     super().__init__()
     self.embed_dim = args.encoder_embed_dim
     self.self_attn = nn.MultiheadAttention(
         self.embed_dim,
         args.encoder_attention_heads,
         dropout=args.attention_dropout,
     )
     self.dropout = args.dropout
     self.relu_dropout = args.relu_dropout
     self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
     self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
     self.layer_norms = nn.ModuleList(
         [FusedLayerNorm(self.embed_dim) for i in range(2)])
Example #16
    def __init__(self,
                 args,
                 dictionary,
                 embed_tokens,
                 no_encoder_attn=False,
                 left_pad=False):
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        embed_dim = embed_tokens.embedding_dim
        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions,
            embed_dim,
            padding_idx,
            left_pad=left_pad,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])
        transformer_print(key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
                          value=args.decoder_layers)
        self.adaptive_softmax = None

        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                args.decoder_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.dropout)
        elif not self.share_input_output_embed:
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=embed_dim**-0.5)
        self.normalize = args.decoder_normalize_before
        if self.normalize:
            self.layer_norm = FusedLayerNorm(embed_dim)
Example #17
    def __init__(self,
                 args,
                 dictionary,
                 embed_tokens,
                 no_encoder_attn=False,
                 left_pad=False):
        super().__init__(dictionary)
        self.dropout = nn.Dropout(p=args.dropout)
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        embed_dim = embed_tokens.embedding_dim
        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = Scale(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions,
            embed_dim,
            padding_idx,
            left_pad=left_pad,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])
        self.adaptive_softmax = None

        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                args.decoder_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.dropout)
        else:
            self.embed_out = nn.Linear(embed_dim, len(dictionary), bias=False)
            nn.init.normal_(self.embed_out.weight, mean=0, std=embed_dim**-0.5)
        self.normalize = args.decoder_normalize_before
        if self.normalize:
            self.layer_norm = FusedLayerNorm(embed_dim)
Example #18
 def __init__(self, args):
     super().__init__()
     self.embed_dim = args.encoder_embed_dim
     self.self_attn = MultiheadAttention(
         self.embed_dim,
         args.encoder_attention_heads,
         dropout=args.attention_dropout,
         softmax_type=args.softmax_type,
     )
     self.dropout = args.dropout
     self.relu_dropout = args.relu_dropout
     self.fuse_dropout_add = args.fuse_dropout_add
     self.fuse_relu_dropout = args.fuse_relu_dropout
     self.normalize_before = args.encoder_normalize_before
     self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
     self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
     self.layer_norms = nn.ModuleList(
         [FusedLayerNorm(self.embed_dim) for i in range(2)])
Example #19
 def __init__(self, args):
     super().__init__()
     self.embed_dim = args.encoder_embed_dim
     self.self_attn = EncoderAttention(
         self.embed_dim,
         args.encoder_attention_heads,
         dropout=args.attention_dropout,
     )
     self.dropout = nn.Dropout(p=args.dropout)
     self.relu_dropout = nn.Dropout(p=args.relu_dropout)
     self.fuse_dropout_add = args.fuse_dropout_add
     self.fuse_relu_dropout = args.fuse_relu_dropout
     self.normalize_before = args.encoder_normalize_before
     self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
     self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
     self.layer_norms = nn.ModuleList(
         [FusedLayerNorm(self.embed_dim) for i in range(2)])
     self.threshold = nn.Threshold(0, 0)
Example #20
    def __init__(self, config):
        super().__init__()
        '''
        Build the final text embeddings, used as input to the TextEncoder or the
        unified Encoder: embeddings = tokenEmd + PosEmd + tokentypeEmb
        If a two-stream model is used, the token type embedding can be dropped here.
        '''
        self.config = config
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size,
                                            padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        if config.model_mode != 'two_flow':
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                      config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model
        # variable name and be able to load any TensorFlow checkpoint file
        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
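As with the image embedder earlier, only __init__ is shown; the docstring above says the result is tokenEmd + PosEmd + tokentypeEmb, with the type term dropped in the two-flow setting. A minimal forward sketch of that rule, written as an illustration rather than the original code:

    def forward(self, input_ids, position_ids, token_type_ids=None):
        embeddings = self.word_embeddings(input_ids) \
            + self.position_embeddings(position_ids)
        if self.config.model_mode != 'two_flow' and token_type_ids is not None:
            embeddings = embeddings + self.token_type_embeddings(token_type_ids)
        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)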
Example #21
    def replace_layer_norm(m, name):

        replacable = True
        try:
            from apex.normalization.fused_layer_norm import FusedLayerNorm
        except ImportError:
            replacable = False

        if replacable:
            for attr_str in dir(m):
                target_attr = getattr(m, attr_str)
                if type(target_attr) == torch.nn.LayerNorm:
                    setattr(
                        m, attr_str,
                        FusedLayerNorm(
                            target_attr.normalized_shape,
                            eps=target_attr.eps,
                            elementwise_affine=target_attr.elementwise_affine))
                    # setattr(m, attr_str,
                    #         SynchronizedBatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum,
                    #                                 target_attr.affine))
            for n, ch in m.named_children():
                replace_layer_norm(ch, n)
Example #22
def check_ln_speed(use_apex, nbatch, nchannel, eps, nrepeat):
    B, C = nbatch, nchannel
    # WarmUp
    for _ in range(2):
        in_data = th.randn(B, C, device=device, dtype=dtype)
        out_data = in_data * in_data
        npy_out_data = out_data.cpu().numpy()
    if not use_apex:
        layer = nn.LayerNorm(in_data.size()[1:], eps=eps)
    else:
        layer = FusedLayerNorm(in_data.size()[1:], eps=eps)
    if args.use_gpu:
        layer.cuda(device)
    if dtype == th.float16:
        layer.half()
    th.cuda.synchronize()
    fwd_time = 0
    bwd_time = 0
    for i in range(nrepeat):
        in_data = th.randn(B, C, device=device, dtype=dtype, requires_grad=True)
        ograd = th.randn(B, C, device=device, dtype=dtype)
        npy_in_data = in_data.cpu().detach().numpy().astype(np.float64)
        gt_out = (npy_in_data - npy_in_data.mean(axis=-1, keepdims=True)) \
                 / np.sqrt(npy_in_data.var(axis=-1, keepdims=True) + eps)
        th.cuda.synchronize()

        # Profile Forward + Backward
        with th.enable_grad():
            th.cuda.synchronize()
            start = time.time()
            out_data = layer(in_data)
            th.cuda.synchronize()
            if i > 0:
                fwd_time += time.time() - start
            start = time.time()
            out_data.backward([ograd])
            th.cuda.synchronize()
            if i > 0:
                bwd_time += time.time() - start
        npy_th_out_data = out_data.cpu().detach().numpy()
        if dtype != th.float16:
            npt.assert_allclose(npy_th_out_data, gt_out.astype(args.dtype), 1E-4, 1E-4)
        else:
            npt.assert_allclose(npy_th_out_data, gt_out.astype(args.dtype), 1E-2, 1E-2)
    return fwd_time / nrepeat * 1000000, bwd_time / nrepeat * 1000000
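check_ln_speed relies on module-level globals (th, np, npt, time, device, dtype, args); a hypothetical driver that compares the unfused and fused kernels on one shape might look like this:

# Returned times are per-iteration averages in microseconds.
for use_apex in (False, True):
    fwd_us, bwd_us = check_ln_speed(use_apex, nbatch=128, nchannel=1024,
                                    eps=1e-5, nrepeat=20)
    name = "FusedLayerNorm" if use_apex else "nn.LayerNorm"
    print("{}: fwd {:.1f} us, bwd {:.1f} us".format(name, fwd_us, bwd_us))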
Example #23
def LayerNorm(embedding_dim):
    m = FusedLayerNorm(embedding_dim)
    return m
Example #24
 def __init__(self, config):
     super(QueryFeatEmbeddings, self).__init__()
     self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                             config.hidden_size)
     self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-5)
     self.dropout = nn.Dropout(config.hidden_dropout_prob)
Example #25
 def __init__(self, embed_dim, normalize_before, fuse=True):
     super().__init__()
     self.embed_dim = embed_dim
     self.normalize_before = normalize_before
     self.ln = FusedLayerNorm(embed_dim) if fuse else nn.LayerNorm(
         embed_dim)
Example #26
class EncdecMultiheadAttn(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """
    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=False,
                 include_norm_add=False,
                 impl='fast'):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.bias = bias
        self.include_norm_add = include_norm_add
        self.impl = impl
        self.scaling = self.head_dim**-0.5

        self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
        self.in_proj_weight_kv = Parameter(
            torch.Tensor(2 * embed_dim, embed_dim))
        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        if self.bias:
            assert impl != 'fast', "ERROR! The Fast implementation does not support biases!"
            self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
            self.in_proj_bias_kv = Parameter(torch.Tensor(2 * embed_dim))
            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
        else:
            self.register_parameter('in_proj_bias_q', None)
            self.register_parameter('in_proj_bias_kv', None)
            self.in_proj_bias_q = None
            self.in_proj_bias_kv = None
            self.out_proj_bias = None
        if self.include_norm_add:
            if impl == 'fast':
                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm = None
            else:
                self.register_parameter('lyr_norm_gamma_weights', None)
                self.register_parameter('lyr_norm_beta_weights', None)
                self.lyr_nrm_gamma_weights = None
                self.lyr_nrm_beta_weights = None
                self.lyr_nrm = FusedLayerNorm(embed_dim)
        self.reset_parameters()

        if self.include_norm_add:
            if impl == 'fast': self.attn_func = fast_encdec_attn_norm_add_func
            elif impl == 'default': self.attn_func = encdec_attn_func
            else: assert False, "Unsupported impl: {} !".format(impl)
        else:
            if impl == 'fast': self.attn_func = fast_encdec_attn_func
            elif impl == 'default': self.attn_func = encdec_attn_func
            else: assert False, "Unsupported impl: {} !".format(impl)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.in_proj_weight_q)
        nn.init.xavier_uniform_(self.in_proj_weight_kv)
        nn.init.xavier_uniform_(self.out_proj_weight)
        if self.bias:
            nn.init.constant_(self.in_proj_bias_q, 0.)
            nn.init.constant_(self.in_proj_bias_kv, 0.)
            nn.init.constant_(self.out_proj_bias, 0.)
        if self.include_norm_add:
            if self.impl == 'fast':
                nn.init.ones_(self.lyr_nrm_gamma_weights)
                nn.init.zeros_(self.lyr_nrm_beta_weights)
            else:
                self.lyr_nrm.reset_parameters()

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                need_weights=False,
                attn_mask=None,
                is_training=True):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """

        if key_padding_mask is not None:
            assert (
                attn_mask is None
            ), "ERROR attn_mask and key_padding_mask should not be both defined!"
            mask = key_padding_mask
        elif attn_mask is not None:
            mask = attn_mask
        else:
            mask = None

        if self.include_norm_add:
            if self.impl == 'fast':
                outputs = self.attn_func(
                    attn_mask is not None, is_training, self.num_heads, query,
                    key, self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
                    self.in_proj_weight_q, self.in_proj_weight_kv,
                    self.out_proj_weight, mask, self.dropout)
            else:
                lyr_nrm_results = self.lyr_nrm(query)
                outputs = self.attn_func(
                    attn_mask is not None, is_training, self.num_heads,
                    self.scaling, lyr_nrm_results, key, self.in_proj_weight_q,
                    self.in_proj_weight_kv, self.out_proj_weight,
                    self.in_proj_bias_q, self.in_proj_bias_kv,
                    self.out_proj_bias, mask, self.dropout)
                if is_training:
                    outputs = jit_dropout_add(outputs, query, self.dropout,
                                              is_training)
                else:
                    outputs = outputs + query
        else:
            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training,
                                         self.num_heads, query, key,
                                         self.in_proj_weight_q,
                                         self.in_proj_weight_kv,
                                         self.out_proj_weight, mask,
                                         self.dropout)
            else:
                outputs = self.attn_func(
                    attn_mask is not None, is_training, self.num_heads,
                    self.scaling, query, key, self.in_proj_weight_q,
                    self.in_proj_weight_kv, self.out_proj_weight,
                    self.in_proj_bias_q, self.in_proj_bias_kv,
                    self.out_proj_bias, mask, self.dropout)

        return outputs, None
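A minimal usage sketch for the module above, assuming apex was built with its contrib multihead_attn extension; inputs follow the Time x Batch x Channel convention from the docstring, and the shapes and dtype here are illustrative:

import torch
from apex.contrib.multihead_attn import EncdecMultiheadAttn  # assumes apex contrib is installed

seq_len, batch, embed_dim, heads = 16, 4, 512, 8
attn = EncdecMultiheadAttn(embed_dim, heads, dropout=0.1, impl='default').cuda().half()
q = torch.randn(seq_len, batch, embed_dim, device='cuda', dtype=torch.half)
kv = torch.randn(seq_len, batch, embed_dim, device='cuda', dtype=torch.half)
out, _ = attn(q, kv, kv, key_padding_mask=None, is_training=True)
print(out.shape)  # torch.Size([16, 4, 512])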
Example #27
class SelfMultiheadAttn(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=False,
        include_norm_add=False,
        impl="fast",
        separate_qkv_params=False,
        mask_additive=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.bias = bias
        self.include_norm_add = include_norm_add
        self.impl = impl
        self.scaling = self.head_dim**-0.5
        self.separate_qkv_params = separate_qkv_params
        self.mask_additive = mask_additive
        if mask_additive:
            assert self.include_norm_add == False, "additive mask not supported with layer norm"
            assert impl == "default" or (
                impl == "fast" and
                bias), "additive mask not supported for fast mode without bias"
        if separate_qkv_params:
            self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        else:
            self.in_proj_weight = Parameter(
                torch.Tensor(3 * embed_dim, embed_dim))
        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        if self.bias:
            if separate_qkv_params:
                self.q_bias = Parameter(torch.Tensor(embed_dim))
                self.k_bias = Parameter(torch.Tensor(embed_dim))
                self.v_bias = Parameter(torch.Tensor(embed_dim))
            else:
                self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
        else:
            if separate_qkv_params:
                self.register_parameter("q_bias", None)
                self.register_parameter("k_bias", None)
                self.register_parameter("v_bias", None)
                self.q_bias = None
                self.k_bias = None
                self.v_bias = None
            else:
                self.register_parameter("in_proj_bias", None)
                self.in_proj_bias = None
            self.register_parameter("out_proj_bias", None)
            self.out_proj_bias = None
        if self.include_norm_add:
            if impl == "fast":
                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm = None
            else:
                self.register_parameter("lyr_norm_gamma_weights", None)
                self.register_parameter("lyr_norm_beta_weights", None)
                self.lyr_nrm_gamma_weights = None
                self.lyr_nrm_beta_weights = None
                self.lyr_nrm = FusedLayerNorm(embed_dim)
        self.reset_parameters()

        if self.include_norm_add:
            if impl == "fast":
                self.attn_func = fast_self_attn_norm_add_func
            elif impl == "default":
                self.attn_func = self_attn_func
            else:
                assert False, "Unsupported impl: {} !".format(impl)
        else:
            if impl == "fast":
                self.attn_func = fast_self_attn_func
            elif impl == "default":
                self.attn_func = self_attn_func
            else:
                assert False, "Unsupported impl: {} !".format(impl)

    def reset_parameters(self):
        if self.separate_qkv_params:
            nn.init.xavier_uniform_(self.q_weight)
            nn.init.xavier_uniform_(self.k_weight)
            nn.init.xavier_uniform_(self.v_weight)
        else:
            # in_proj_weight has shape [3 * hidden, hidden] but it should be
            # initialized like a [hidden, hidden] matrix.
            # sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
            # therefore xavier_uniform gain should be set to sqrt(2).
            nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj_weight)
        if self.bias:
            if self.separate_qkv_params:
                nn.init.constant_(self.q_bias, 0.0)
                nn.init.constant_(self.k_bias, 0.0)
                nn.init.constant_(self.v_bias, 0.0)
            else:
                nn.init.constant_(self.in_proj_bias, 0.0)
            nn.init.constant_(self.out_proj_bias, 0.0)
        if self.include_norm_add:
            if self.impl == "fast":
                nn.init.ones_(self.lyr_nrm_gamma_weights)
                nn.init.zeros_(self.lyr_nrm_beta_weights)
            else:
                self.lyr_nrm.reset_parameters()

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                need_weights=False,
                attn_mask=None,
                is_training=True):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        if self.separate_qkv_params:
            input_weights = (torch.cat(
                [
                    self.q_weight.view(self.num_heads, 1, self.head_dim,
                                       self.embed_dim),
                    self.k_weight.view(self.num_heads, 1, self.head_dim,
                                       self.embed_dim),
                    self.v_weight.view(self.num_heads, 1, self.head_dim,
                                       self.embed_dim),
                ],
                dim=1,
            ).reshape(3 * self.embed_dim, self.embed_dim).contiguous())
        else:
            input_weights = self.in_proj_weight
        if self.bias:
            if self.separate_qkv_params:
                input_bias = (torch.cat(
                    [
                        self.q_bias.view(self.num_heads, 1, self.head_dim),
                        self.k_bias.view(self.num_heads, 1, self.head_dim),
                        self.v_bias.view(self.num_heads, 1, self.head_dim),
                    ],
                    dim=1,
                ).reshape(3 * self.embed_dim).contiguous())
            else:
                input_bias = self.in_proj_bias
        else:
            input_bias = None
        if key_padding_mask is not None:
            assert attn_mask is None, "ERROR attn_mask and key_padding_mask should not be both defined!"
            mask = key_padding_mask
        elif attn_mask is not None:
            assert self.mask_additive == False, "additive mask not supported for time mask"
            mask = attn_mask
        else:
            mask = None

        if self.include_norm_add:
            if self.impl == "fast":
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    query,
                    self.lyr_nrm_gamma_weights,
                    self.lyr_nrm_beta_weights,
                    input_weights,
                    self.out_proj_weight,
                    mask,
                    self.dropout,
                )
            else:
                lyr_nrm_results = self.lyr_nrm(query)
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    self.scaling,
                    lyr_nrm_results,
                    input_weights,
                    self.out_proj_weight,
                    input_bias,
                    self.out_proj_bias,
                    mask,
                    self.mask_additive,
                    self.dropout,
                )
                if is_training:
                    outputs = jit_dropout_add(outputs, query, self.dropout,
                                              is_training)
                else:
                    outputs = outputs + query
        else:
            if self.impl == "fast":
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    query,
                    input_weights,
                    self.out_proj_weight,
                    input_bias,
                    self.out_proj_bias,
                    mask,
                    self.mask_additive,
                    self.dropout,
                )
            else:
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    self.scaling,
                    query,
                    input_weights,
                    self.out_proj_weight,
                    input_bias,
                    self.out_proj_bias,
                    mask,
                    self.mask_additive,
                    self.dropout,
                )

        return outputs, None
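Per the docstring, self-attention is obtained by passing the same tensor for query, key and value. A minimal sketch under the same assumption that apex's contrib extension is available:

import torch
from apex.contrib.multihead_attn import SelfMultiheadAttn  # assumes apex contrib is installed

seq_len, batch, embed_dim, heads = 16, 4, 512, 8
attn = SelfMultiheadAttn(embed_dim, heads, dropout=0.1, impl='default').cuda().half()
x = torch.randn(seq_len, batch, embed_dim, device='cuda', dtype=torch.half)
out, _ = attn(x, x, x, key_padding_mask=None, is_training=True)
print(out.shape)  # torch.Size([16, 4, 512])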
Example #28
 def __init__(self, hidden_size, feat_dim):
     super().__init__()
     self.net = nn.Sequential(nn.Linear(hidden_size, hidden_size), GELU(),
                              FusedLayerNorm(hidden_size, eps=1e-5),
                              nn.Linear(hidden_size, feat_dim))
Example #29
 def __init__(self, norm_size):
     super(ApexLayerNorm, self).__init__()
     self.layer_norm = FusedLayerNorm(norm_size, eps=1e-12)
Example #30
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=False,
        include_norm_add=False,
        impl="fast",
        separate_qkv_params=False,
        mask_additive=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.bias = bias
        self.include_norm_add = include_norm_add
        self.impl = impl
        self.scaling = self.head_dim**-0.5
        self.separate_qkv_params = separate_qkv_params
        self.mask_additive = mask_additive
        if mask_additive:
            assert self.include_norm_add == False, "additive mask not supported with layer norm"
            assert impl == "default" or (
                impl == "fast" and
                bias), "additive mask not supported for fast mode without bias"
        if separate_qkv_params:
            self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        else:
            self.in_proj_weight = Parameter(
                torch.Tensor(3 * embed_dim, embed_dim))
        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        if self.bias:
            if separate_qkv_params:
                self.q_bias = Parameter(torch.Tensor(embed_dim))
                self.k_bias = Parameter(torch.Tensor(embed_dim))
                self.v_bias = Parameter(torch.Tensor(embed_dim))
            else:
                self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
        else:
            if separate_qkv_params:
                self.register_parameter("q_bias", None)
                self.register_parameter("k_bias", None)
                self.register_parameter("v_bias", None)
                self.q_bias = None
                self.k_bias = None
                self.v_bias = None
            else:
                self.register_parameter("in_proj_bias", None)
                self.in_proj_bias = None
            self.register_parameter("out_proj_bias", None)
            self.out_proj_bias = None
        if self.include_norm_add:
            if impl == "fast":
                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm = None
            else:
                self.register_parameter("lyr_norm_gamma_weights", None)
                self.register_parameter("lyr_norm_beta_weights", None)
                self.lyr_nrm_gamma_weights = None
                self.lyr_nrm_beta_weights = None
                self.lyr_nrm = FusedLayerNorm(embed_dim)
        self.reset_parameters()

        if self.include_norm_add:
            if impl == "fast":
                self.attn_func = fast_self_attn_norm_add_func
            elif impl == "default":
                self.attn_func = self_attn_func
            else:
                assert False, "Unsupported impl: {} !".format(impl)
        else:
            if impl == "fast":
                self.attn_func = fast_self_attn_func
            elif impl == "default":
                self.attn_func = self_attn_func
            else:
                assert False, "Unsupported impl: {} !".format(impl)