Example #1
    def forward(inputs, double_precision, dropout_prob, is_training, heads):
        """
        :param heads: number of heads
        :param is_training: training state, for dropout
        :param dropout_prob: dropout probability
        :param inputs: attention scores; softmax is taken over the last dimension
        :param double_precision: ops at float64, only for debugging
        :return: dropout_mask, softmax_results, dropout_results
        """

        len_k = inputs.size(-1)
        if mask_softmax_dropout_cuda and len_k <= 2048 and inputs.type() == 'torch.cuda.HalfTensor':

            dropout_mask, softmax_results, dropout_results = \
                mask_softmax_dropout_cuda.forward(is_training, heads, inputs, dropout_prob)

            if not is_training:
                # the CUDA kernel returns an empty dropout result at inference
                dropout_results = softmax_results

        else:
            dtype_ = torch.float64 if double_precision else torch.float32
            softmax_results = F.softmax(inputs, dim=-1, dtype=dtype_)

            # Dropout - is not executed for inference
            if is_training:
                dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1. - dropout_prob))
            else:
                dropout_results = softmax_results
                dropout_mask = torch.tensor([])

        return dropout_mask, softmax_results, dropout_results
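
For reference, the fallback branch above amounts to a softmax over the key dimension followed by inverted dropout. The standalone sketch below mirrors that computation; the tensor shapes and the 0.1 dropout rate are illustrative assumptions only, not taken from the example.

import torch
import torch.nn.functional as F

def softmax_dropout_reference(attn_score, dropout_prob, is_training):
    # Softmax over the last (key) dimension, followed by inverted dropout,
    # mirroring the non-fused branch of the example above.
    softmax_results = F.softmax(attn_score, dim=-1)
    dropout_results = F.dropout(softmax_results, p=dropout_prob, training=is_training)
    return softmax_results, dropout_results

# Illustrative shapes: 2 sentences, 8 heads, query/key length 16 -> [bsz*heads, len_q, len_k]
scores = torch.randn(2 * 8, 16, 16)
softmax_out, dropout_out = softmax_dropout_reference(scores, 0.1, is_training=True)
print(softmax_out.shape, dropout_out.shape)  # both torch.Size([16, 16, 16])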
Example #2
    def forward(ctx, inputs, pos, use_time_mask, is_training, heads,
                input_weights, output_weights, pos_weights, input_biases,
                output_biases, pos_biases, r_w_bias, r_r_bias, mask,
                dropout_prob, incremental, incremental_cache, double_precision,
                learnable_pos, return_coverage, recompute):
        """
        :param recompute: recompute intermediate results in the backward pass to save memory
        :param return_coverage: also return the attention weights together with the outputs
        :param learnable_pos: positions are learnable embeddings used directly (no position projection)
        :param double_precision: ops at float64, only for debugging
        :param ctx: context object to stash information for backward
        :param inputs: input hidden states [len_q x batch_size x hidden]
        :param pos: [len_k x 1 x hidden]
        :param use_time_mask: bool, if we use the causal mask for decoder
        :param is_training: training state, for dropout
        :param heads: number of heads
        :param input_weights: weight matrix [3*hidden x hidden]
        :param output_weights: output weight [hidden x hidden]
        :param input_biases: bias [3*hidden]
        :param output_biases: output bias [hidden]
        :param pos_biases: position projection bias [hidden]
        :param pos_weights: position projection weight [hidden x hidden]
        :param r_w_bias: content bias added to the queries [heads x head_dim]
        :param r_r_bias: position bias added to the queries [heads x head_dim]
        :param mask: None or [B x T] or [T x T]
        :param dropout_prob: attention dropout probability
        :param incremental: bool, incremental decoding
        :param incremental_cache: dict caching keys/values for incremental decoding
        :return: outputs [len_q x batch_size x hidden], plus the attention weights if return_coverage
        """

        heads_t = torch.tensor([heads])
        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([]).to(inputs.device)
        head_dim = inputs.size(2) // heads
        scale_t = torch.tensor([head_dim**-0.5])
        ctx.double_precision = double_precision
        ctx.fused_softmax_dropout = False
        ctx.learnable_pos = learnable_pos
        ctx.return_coverage = return_coverage
        ctx.fused_all = False
        ctx.recompute = recompute

        bsz, len_q = inputs.size(1), inputs.size(0)
        len_r = pos.size(
            0
        )  # r can be longer than query, i.e for bidirectional attention we need 2k+1 positions
        len_k = len_q  # because of self-attention

        if mask is not None:
            mask = mask.to(torch.bool)
            # Self Attention Time Mask
            if use_time_mask:
                assert (len(mask.size()) == 2), "Timing mask is not 2D!"
                # assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
                mask = mask.unsqueeze(0).unsqueeze(0)
            # Key Padding Mask
            else:
                # attn_score = attn_score.view(bsz, heads, len_q, len_k)
                mask = mask.unsqueeze(1).unsqueeze(2)

        if rel_self_attn_cuda is not None and not incremental and len_k <= 2048 and \
                inputs.type() == 'torch.cuda.HalfTensor' and learnable_pos:

            input_lin_results, rr_head_q, rw_head_q, \
            softmax_results, dropout_results, dropout_mask, \
            matmul2_results, outputs \
                = rel_self_attn_cuda.forward(is_training, heads, inputs, pos,
                                             input_weights, output_weights,
                                             input_biases, output_biases,
                                             r_w_bias, r_r_bias,
                                             mask, dropout_prob)

            pos_lin_results = None
            r_head_k = None
            nan_mask = null_tensor

            if recompute:
                ctx.save_for_backward(heads_t, scale_t, inputs, pos, r_head_k,
                                      input_weights, pos_weights,
                                      output_weights, input_biases, pos_biases,
                                      output_biases, r_w_bias, r_r_bias,
                                      dropout_mask, nan_mask, mask,
                                      dropout_prob_t)
            else:
                ctx.save_for_backward(heads_t, scale_t, matmul2_results,
                                      dropout_results, softmax_results,
                                      input_lin_results, pos_lin_results,
                                      rw_head_q, rr_head_q, inputs, pos,
                                      r_head_k, input_weights, pos_weights,
                                      output_weights, dropout_mask, nan_mask,
                                      dropout_prob_t)

            ctx.fused_all = True
            if return_coverage:
                return (outputs, dropout_results)
            else:
                return outputs

        if pos.size(1) == 1 and not learnable_pos:
            pos = pos.repeat(
                1, bsz, 1
            )  # we have to use repeat instead of expand here because mm needs contiguous

        # Input Linear GEMM
        # input1: (activations) [len_q, bsz, hidden]
        # input2: (weights)     [hidden*3 (3072), hidden (1024)] (transpose [0,1])
        # output:               [len_q, bsz, hidden*3]
        # GEMM: ( (len_q*bsz) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (len_q*bsz x embed_dim*3)
        input_lin_results = torch.addmm(input_biases,
                                        inputs.view(
                                            inputs.size(0) * inputs.size(1),
                                            inputs.size(2)),
                                        input_weights.transpose(0, 1),
                                        beta=1.,
                                        alpha=1.)

        # reshape [len_q*bsz, embed_dim*3 -> len_q x bsz x embed_dim*3]
        input_lin_results = input_lin_results.view(inputs.size(0),
                                                   inputs.size(1),
                                                   input_weights.size(0))
        # check = torch.allclose(input_lin_results, input_lin_results2)
        # print("Check linear in", check)

        if not learnable_pos:
            pos_lin_results = torch.addmm(pos_biases,
                                          pos.view(
                                              pos.size(0) * pos.size(1),
                                              pos.size(2)),
                                          pos_weights.transpose(0, 1),
                                          beta=1.,
                                          alpha=1.)

            pos_lin_results = pos_lin_results.view(pos.size(0), pos.size(1),
                                                   pos_weights.size(0))

            r_head_k = pos_lin_results.view(pos.size(0), bsz * heads,
                                            head_dim)  # T x BxH x D
        else:
            # pos_lin_results = pos.view(pos.size(0), bsz * heads, head_dim)  # T x BxH x D
            # r_head_k = pos_lin_results
            pos_lin_results = None
            r_head_k = None

        # Slice out q,k,v from one big Input Linear output (should only impact meta data, no copies!)
        # Sequences and heads are combined to make the batch of the Batched GEMM
        # input_lin_results: [len_q, bsz, heads(16), 3, head_dim(64)]
        # input_lin_results: [len_q, batches=bsz*heads, 3, head_dim]
        input_lin_results = input_lin_results.view(inputs.size(0),
                                                   inputs.size(1) * heads, 3,
                                                   head_dim)
        queries = input_lin_results[:, :, 0, :]
        keys = input_lin_results[:, :, 1, :]
        values = input_lin_results[:, :, 2, :]

        if incremental:
            # We have to change the heads x head_dim first and then concat to the T dim
            # bsz is changed during translation due to beam search
            # during translation we want to keep the actual T dim in MM as 1 constantly
            keys = keys.reshape(len_q, bsz, heads * head_dim)
            values = values.reshape(len_q, bsz, heads * head_dim)

            if 'k' in incremental_cache and 'v' in incremental_cache:
                keys = torch.cat([incremental_cache['k'], keys],
                                 dim=0)  # time first
                incremental_cache['k'] = keys
                values = torch.cat([incremental_cache['v'], values],
                                   dim=0)  # time first
                incremental_cache['v'] = values
            else:
                incremental_cache['k'] = keys
                incremental_cache['v'] = values

            keys = keys.view(-1, bsz * heads, head_dim)
            values = values.view(-1, bsz * heads, head_dim)
            # re-update len_k to be the newly updated length of the keys
            len_k = keys.size(0)
        # Relative Attention from here:
        # r_w_bias size: head * head_dim
        rw_head_q = queries.view(len_q, bsz, heads, head_dim) + r_w_bias  # add the content bias (r_w_bias) for the AC term
        rw_head_q = rw_head_q.view(len_q, bsz * heads, head_dim)

        # matmul_ac batched GEMMs
        # queries+bias: [len_q, bsz*heads, head_dim] transpose(0, 1)
        # keys: [len_k, bsz*heads, head_dim] transpose(0, 1)
        if queries.is_cuda:
            matmul_ac = torch.empty(
                (bsz * heads, queries.size(0), keys.size(0)),
                dtype=queries.dtype,
                device=rw_head_q.device)
            matmul_ac = torch.baddbmm(matmul_ac,
                                      rw_head_q.transpose(0, 1),
                                      keys.transpose(0, 1).transpose(1, 2),
                                      out=matmul_ac,
                                      beta=0.0,
                                      alpha=scale_t[0])
        else:
            matmul_ac = torch.bmm(rw_head_q.transpose(0, 1),
                                  keys.transpose(0, 1).transpose(1, 2)).mul_(
                                      scale_t[0])

        rr_head_q = queries.view(len_q, bsz, heads, head_dim) + r_r_bias  # add the position bias (r_r_bias) for the BD term
        # check = torch.allclose(rr_head_q.view(len_q, bsz, -1), rr_head_q2, rtol=1e-03, atol=1e-04)
        # print("Check rr_head_q", check)
        rr_head_q = rr_head_q.view(len_q, bsz * heads, head_dim)

        if not learnable_pos:
            if queries.is_cuda:
                # matmul2 batched GEMMs
                # queries+bias: [len_q, bsz*heads, head_dim] transpose(0, 1)
                # rel_positions: [len_r, bsz*heads, head_dim] transpose(0, 1)
                matmul_bd = torch.empty((bsz * heads, queries.size(0), len_r),
                                        dtype=queries.dtype,
                                        device=rw_head_q.device)
                matmul_bd = torch.baddbmm(matmul_bd,
                                          rr_head_q.transpose(0, 1),
                                          r_head_k.transpose(0, 1).transpose(
                                              1, 2),
                                          out=matmul_bd,
                                          beta=0.0,
                                          alpha=scale_t[0])
            else:
                matmul_bd = torch.matmul(rr_head_q.transpose(0, 1), r_head_k.transpose(0, 1).transpose(1, 2)) \
                    .mul_(scale_t[0])

            # shift so that the relative positions are aligned
            # the first element will have 0 -1 ... -n relative positions compared to other elements
            # the last element will have  n-1 n-2 ...  0
            matmul_bd = RelativeShift.forward(matmul_bd, True, False)

            # if len_r is longer than len_k, then we need to take the first len_k positions only
            matmul_bd = matmul_bd[:, :, :len_k]

            attn_score = matmul_ac + matmul_bd  # both AC and BD are scaled with scale_t before in baddbmm
        else:
            # matmul2 batched GEMMs
            # queries+bias: [len_q, bsz*heads, head_dim]
            # rel_positions: [len_q, len_k, head_dim] transpose(1, 2)
            # add directly into matmul_ac so we don't need to
            # torch.baddbmm(matmul_ac.transpose(0, 1), rr_head_q, pos.transpose(1, 2),
            # out=matmul_ac.transpose(0, 1), beta=1.0, alpha=scale_t[0])
            matmul_ac.transpose(0, 1).baddbmm_(rr_head_q,
                                               pos.transpose(1, 2),
                                               beta=1.0,
                                               alpha=scale_t[0])
            attn_score = matmul_ac
            # no need to shift in this case

        # attn_score should have size [bsz*heads, len_q, len_k] for now

        if mask is not None:
            attn_score.view(bsz, heads, len_q,
                            len_k).masked_fill_(mask, float('-inf'))

        if not (mask_softmax_dropout_cuda is not None and len_k <= 2048
                and attn_score.type()
                == 'torch.cuda.HalfTensor') or double_precision:

            dtype_ = torch.float64 if double_precision else torch.float32
            softmax_results = F.softmax(attn_score, dim=-1, dtype=dtype_).type_as(attn_score)

            # Dropout - is not executed for inference
            if is_training:
                dropout_results, dropout_mask = torch._fused_dropout(
                    softmax_results, p=(1. - dropout_prob_t[0]))
            else:
                dropout_results = softmax_results
                dropout_mask = null_tensor
            ctx.fused_softmax_dropout = False
        else:
            # Fused Softmax and Dropout
            # ASSERTED To produce the same result with F.softmax
            dropout_mask, softmax_results, dropout_results = \
                mask_softmax_dropout_cuda.forward(is_training, heads, attn_score, dropout_prob_t[0])

            if not is_training:
                dropout_results = softmax_results

            # Verification
            # softmax_results_ref = F.softmax(attn_score, dim=-1)
            # if is_training:
            #     dropout_results_ref = softmax_results_ref * dropout_mask.half() * (1 / (1 - dropout_prob_t[0]))
            # else:
            #     dropout_results_ref = softmax_results_ref
            #
            # comp = torch.allclose(softmax_results_ref, softmax_results, rtol=1e-03, atol=1e-04)
            # comp = torch.allclose(dropout_results_ref, dropout_results, rtol=1e-03, atol=1e-04)
            # if comp:
            #     print("Forward pass verification passed.")
            # else:
            #     print("ERROR: Forward pass verification failed")
            # print(dropout_results - dropout_results_ref)
            # print(softmax_results)
            # Done Verification

            ctx.fused_softmax_dropout = True

        nan_mask = null_tensor
        # nan_mask = torch.isnan(softmax_results)
        # if nan_mask.any():
        #     softmax_results.masked_fill_(nan_mask, 0)

        # Matmul2 Batched GEMMs
        # Input1: from_softmax [bsz*heads, len_q, seql_k]
        # Input2: (values)     [seql_v, bsz*heads, head_dim] transpose(0,1)
        # Output:              [bsz*heads, len_q, head_dim]
        # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = (len_q x head_dim)
        matmul2_results = torch.bmm(dropout_results,
                                    values.transpose(0, 1)).transpose(0, 1)
        # if learnable_pos:
        #     # Input1: from_softmax [bsz*heads, len_q, seql_k].transpose(0, 1)
        #     # Input2: R [len_q, len_k, head_dim]
        #     # Output: [ len_q, bsz*heads, head_dim]
        #     torch.baddbmm(matmul2_results, dropout_results.transpose(0, 1), pos, beta=1.0, alpha=1.0,
        #                   out=matmul2_results)

        matmul2_results = matmul2_results.contiguous().view(
            inputs.size(0), inputs.size(1), inputs.size(2))

        # Output Linear GEMM
        # Input1: (activations) [len_q, bsz, embed_dim=heads*head_dim]
        # Input2: (weights)     [ embed_dim, embed_dim ] transpose(0,1)
        # Output:               [ len_q, bsz, embed_dim ]
        # GEMM: ( len_q*bsz x embed_dim ) x ( embed_dim x embed_dim ) = ( len_q*bsz x embed_dim )
        outputs = torch.addmm(output_biases,
                              matmul2_results.view(
                                  inputs.size(0) * inputs.size(1),
                                  inputs.size(2)),
                              output_weights.transpose(0, 1),
                              beta=1.,
                              alpha=1.)

        outputs = outputs.view(inputs.size(0), inputs.size(1),
                               output_weights.size(0))

        if recompute:
            ctx.save_for_backward(heads_t, scale_t, inputs, pos, r_head_k,
                                  input_weights, pos_weights, output_weights,
                                  input_biases, pos_biases, output_biases,
                                  r_w_bias, r_r_bias, dropout_mask, nan_mask,
                                  mask, dropout_prob_t)

            # delete stuff here
            del input_lin_results, queries, keys, values
            del matmul_ac, matmul2_results, attn_score, softmax_results, dropout_results
            del rr_head_q, rw_head_q
            if not learnable_pos:
                del matmul_bd

            dropout_results = null_tensor

        else:
            ctx.save_for_backward(heads_t, scale_t, matmul2_results,
                                  dropout_results, softmax_results,
                                  input_lin_results, pos_lin_results,
                                  rw_head_q, rr_head_q, inputs, pos, r_head_k,
                                  input_weights, pos_weights, output_weights,
                                  dropout_mask, nan_mask, dropout_prob_t)

            del attn_score

        if return_coverage:
            return (outputs, dropout_results)
        else:
            return outputs
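
The RelativeShift.forward(matmul_bd, True, False) call above realigns the query-by-position scores as described in the comments (row i ends up seeing relative offsets aligned with the keys). Its implementation is not shown in these examples; the pad-and-reshape trick below is one common way such a shift is realized, sketched under that assumption and without claiming to match the exact semantics of the boolean flags.

import torch

def rel_shift(x):
    # x: [bsz*heads, len_q, len_r] scores of each query against each relative position.
    # Prepend a zero column, fold it into the rows, then drop it: row i ends up
    # shifted left by (len_q - 1 - i), aligning columns with relative offsets.
    # Entries that wrap in from beyond a row are meaningless and are expected to be
    # truncated (as matmul_bd[:, :, :len_k] does above) or masked afterwards.
    b, len_q, len_r = x.size()
    zero_pad = x.new_zeros(b, len_q, 1)
    x_padded = torch.cat([zero_pad, x], dim=2)        # [b, len_q, len_r + 1]
    x_padded = x_padded.view(b, len_r + 1, len_q)     # fold the pad column into rows
    return x_padded[:, 1:, :].view(b, len_q, len_r)   # drop the pad, restore the shape

# Toy check with len_q = 2, len_r = 3:
print(rel_shift(torch.arange(6.).view(1, 2, 3)))
# tensor([[[1., 2., 0.],
#          [3., 4., 5.]]])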
Example #3
    def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv,
                input_weights_q, input_weights_kv, output_weights,
                mask, dropout_prob,
                incremental, incremental_cache,
                double_precision, return_coverage):
        heads_t = torch.tensor([heads])
        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        head_dim = inputs_q.size(2) // heads
        scale_t = torch.tensor([head_dim ** -0.5])
        use_mask = (mask is not None)

        bsz, len_q, len_k = inputs_q.size(1), inputs_q.size(0), inputs_kv.size(0)
        ctx.incremental = incremental
        ctx.fused_softmax_dropout = False
        ctx.fused_all = False
        ctx.len_q = len_q
        ctx.len_k = len_k
        ctx.double_precision = double_precision
        ctx.return_coverage = return_coverage

        if mask is not None:
            # Self Attention Pad Mask
            mask = mask.to(torch.bool)

            if len(mask.shape) == 3:
                mask = mask.unsqueeze(1)  # for the head dimension
            else:
                mask = mask.unsqueeze(1).unsqueeze(2)  # for the head and query dimension

        if encdec_multihead_attn_cuda is not None and not incremental and len_k <= 2048\
                and inputs_q.type() == 'torch.cuda.HalfTensor':
            input_lin_q_results, input_lin_kv_results, \
                softmax_results, dropout_results, dropout_mask, \
                matmul2_results, outputs \
                = encdec_multihead_attn_cuda.forward(is_training, heads, inputs_q, inputs_kv,
                                                     input_weights_q, input_weights_kv,
                                                     output_weights, mask, dropout_prob)

            ctx.save_for_backward(heads_t,
                                  scale_t,
                                  matmul2_results,
                                  dropout_results,
                                  softmax_results,
                                  input_lin_q_results,
                                  input_lin_kv_results,
                                  inputs_q,
                                  inputs_kv,
                                  input_weights_q,
                                  input_weights_kv,
                                  output_weights,
                                  dropout_mask,
                                  dropout_prob_t)
            ctx.fused_all = True

            if return_coverage:
                return outputs, softmax_results
            else:
                return (outputs, )

        # Input Linear GEMM Q
        # input1: (activations) [seql_q, bsz, embed_dim] -> [len_q * bsz, embed_dim]
        # input2: (weights)     [embed_dim, embed_dim]. transpose(0, 1)
        # output:               [len_q * bsz, embed_dim] -> [seql_q, bsz, embed_dim]
        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
        input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
                                       input_weights_q.transpose(0, 1))
        input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))

        queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads, head_dim)

        # Input Linear GEMM KV
        # input1: (activations) [seql_k, bsz, embed_dim(1024)]
        # input2: (weights)     [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])
        # output:               [seql_k, bsz, embed_dim*2]
        # GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)

        # Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)
        # Sequences and heads are combined to make the batch of the Batched GEMM

        if incremental and ('c_k' in incremental_cache and 'c_v' in incremental_cache):
            keys = incremental_cache['c_k']
            values = incremental_cache['c_v']
            keys = keys.view(len_k, bsz * heads, head_dim)
            values = values.view(len_k, bsz * heads, head_dim)
            input_lin_kv_results = torch.stack([keys, values], dim=-2)
        else:
            input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
                                            input_weights_kv.transpose(0, 1))
            input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1),
                                                             input_weights_kv.size(0))

            input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1) * heads, 2, head_dim)
            keys = input_lin_kv_results[:, :, 0, :]
            values = input_lin_kv_results[:, :, 1, :]
            if incremental:
                keys = keys.contiguous().view(len_k, bsz, heads * head_dim)
                values = values.contiguous().view(len_k, bsz, heads * head_dim)

                incremental_cache['c_k'] = keys
                incremental_cache['c_v'] = values

                keys = keys.view(len_k, bsz * heads, head_dim)
                values = values.view(len_k, bsz * heads, head_dim)

        # Matmul1 Batched GEMMs
        # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
        # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
        # a separate elementwise operation.
        # Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Input2: (Keys)    [seql_k, seqs*heads, head_dim] transpose(0,1)
        # output:           [seqs*heads, seql_q, seql_k]
        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
        if queries.is_cuda:
            matmul1_results = torch.empty((queries.size(1), queries.size(0), keys.size(0)), dtype=queries.dtype,
                                          device=queries.device)
            matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0, 1),
                                            keys.transpose(0, 1).transpose(1, 2),
                                            out=matmul1_results, beta=0.0, alpha=scale_t[0])
        else:
            matmul1_results = torch.matmul(queries.transpose(0, 1), keys.transpose(0, 1).transpose(1, 2))
            matmul1_results.mul_(scale_t[0])

        if mask is not None:
            batches, seql_q, seql_k = matmul1_results.size()
            bsz = int(batches / heads)
            matmul1_results = matmul1_results.view(bsz, heads, seql_q, seql_k)
            # after unsqueezing the mask should have size [bsz x 1 x 1 x seql_k]
            matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
            matmul1_results = matmul1_results.view(bsz * heads, seql_q, seql_k)

        if mask_softmax_dropout_cuda and len_k <= 2048 \
                and matmul1_results.type() == 'torch.cuda.HalfTensor' and not double_precision:
            # if False:
            # dropout_results_ref = F.softmax(matmul1_results, dim=-1)
            dropout_mask, softmax_results, dropout_results = mask_softmax_dropout_cuda.forward(is_training, heads,
                                                                                               matmul1_results,
                                                                                               dropout_prob_t[0])
            if not is_training:
                dropout_results = softmax_results  # the CUDA kernel returns an empty dropout result at inference

            # Verification code
            # softmax_results_ref = F.softmax(matmul1_results, dim=-1)
            # #
            # if is_training:
            #     # print(dropout_mask.float().sum(), dropout_mask.numel())
            #     dropout_results_ref = softmax_results_ref * dropout_mask.half() * (1 / (1 - dropout_prob_t[0]))
            # else:
            #     dropout_results_ref = softmax_results_ref
            # #
            # comp = torch.allclose(softmax_results_ref, softmax_results, rtol=1e-03, atol=1e-04)
            # comp = torch.allclose(dropout_results_ref, dropout_results, rtol=1e-03, atol=1e-04)
            # if comp:
            #     print("Forward pass verification passed.")
            # else:
            #     print("ERROR: Forward pass verification failed")
            # Verification done

            ctx.fused_softmax_dropout = True

        else:
            # dtype_ = torch.float64 if double_precision else torch.float32
            # softmax_results = F.softmax(matmul1_results, dim=-1, dtype=dtype_).type_as(matmul1_results)
            if matmul1_results.type() == 'torch.cuda.HalfTensor':
                softmax_results = F.softmax(matmul1_results, dim=-1, dtype=torch.float32).type_as(matmul1_results)
            else:
                softmax_results = F.softmax(matmul1_results, dim=-1)

            # Dropout - is not executed for inference
            if is_training:
                dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1. - dropout_prob_t[0]))
            else:
                dropout_results = softmax_results
                dropout_mask = null_tensor

        # Matmul2 Batched GEMMs
        # The output tensor specification is needed here to specify the non-standard output.
        # Given that pytorch cannot currently perform autograd with an output tensor specified,
        # this requires a backward pass specified.
        # Input1: from_softmax [seqs*heads, seql_q, seql_k]
        # Input2: (values)     [seql_v, seqs*heads, head_dim] transpose(0,1)
        # Output:              [seql_q, seqs*heads, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)

        if queries.is_cuda:
            matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)),
                                          dtype=dropout_results.dtype, device=dropout_results.device)
            torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results.transpose(1, 0))
        else:
            matmul2_results = torch.matmul(dropout_results, values.transpose(0, 1)).transpose(0, 1)

        # view from [len_q, bsz*heads, head_dim] to [len_q, bsz, embed]
        matmul2_results = matmul2_results.contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))

        # Output Linear GEMM
        # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
        # Input2: (weights)     [ embed_dim, embed_dim ] transpose(0,1)
        # Output:               [ seql_q, seqs, embed_dim ]
        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
        outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
                           output_weights.transpose(0, 1))
        outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))

        ctx.save_for_backward(heads_t,
                              scale_t,
                              matmul2_results,
                              dropout_results,
                              softmax_results,
                              input_lin_q_results,
                              input_lin_kv_results,
                              inputs_q,
                              inputs_kv,
                              input_weights_q,
                              input_weights_kv,
                              output_weights,
                              dropout_mask,
                              dropout_prob_t)

        if return_coverage:
            return (outputs, dropout_results)
        else:
            return (outputs, )
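
As a shape reference, the sketch below re-implements the non-fused path of this example in plain PyTorch (Q and KV projections, scaled dot-product, softmax/dropout, output projection), with the key-padding mask omitted for brevity. All sizes and the random weights are illustrative assumptions only.

import torch
import torch.nn.functional as F

heads, hidden, bsz, len_q, len_k = 8, 512, 4, 10, 20
head_dim = hidden // heads
inputs_q = torch.randn(len_q, bsz, hidden)    # [time x batch x hidden], as in the example
inputs_kv = torch.randn(len_k, bsz, hidden)
w_q = torch.randn(hidden, hidden)             # input_weights_q
w_kv = torch.randn(2 * hidden, hidden)        # input_weights_kv (keys and values stacked)
w_out = torch.randn(hidden, hidden)           # output_weights

# Q projection and KV projection, then split keys/values from the stacked result
q = torch.mm(inputs_q.view(-1, hidden), w_q.t()).view(len_q, bsz * heads, head_dim)
kv = torch.mm(inputs_kv.view(-1, hidden), w_kv.t()).view(len_k, bsz * heads, 2, head_dim)
k, v = kv[:, :, 0, :], kv[:, :, 1, :]

# Scaled dot-product attention per (batch x head), softmax over keys, dropout
scores = torch.bmm(q.transpose(0, 1), k.transpose(0, 1).transpose(1, 2)) * head_dim ** -0.5
attn = F.dropout(F.softmax(scores, dim=-1), p=0.1, training=True)

# Weighted sum of values, back to [len_q x bsz x hidden], then output projection
context = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1).contiguous().view(len_q, bsz, hidden)
outputs = torch.mm(context.view(-1, hidden), w_out.t()).view(len_q, bsz, hidden)
print(outputs.shape)  # torch.Size([10, 4, 512])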