Example #1
def softmax_backward_data(parent, grad_output, output, dim, self):
    """
    A function that calls the internal `_softmax_backward_data` PyTorch method, adjusting its arguments
    to the detected torch version.
    """

    if is_torch_less_than_1_11:
        return _softmax_backward_data(grad_output, output, parent.dim, self)
    else:
        return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
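As the wrapper above shows, only the last argument of the private op changes across torch versions; the value it computes is the standard softmax backward. A minimal self-contained sketch (arbitrary shapes, not part of the example) checking that identity against autograd:

import torch

# grad_in = output * (grad_out - (grad_out * output).sum(dim, keepdim=True))
x = torch.randn(4, 7, requires_grad=True)
y = torch.softmax(x, dim=-1)
grad_out = torch.randn_like(y)

manual = y * (grad_out - (grad_out * y).sum(dim=-1, keepdim=True))
(auto,) = torch.autograd.grad(y, x, grad_out)
assert torch.allclose(manual, auto, atol=1e-6)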
Example #2
  def backward(self, grad_output):
    """
    """

    output, = self.saved_tensors
    inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
    return inputGrad, None, None
    def backward(grad_outputs, softmax_results, dropout_prob_t, heads_t, dropout_mask):
        len_key = softmax_results.size(-1)

        if mask_softmax_dropout_cuda is not None and grad_outputs.type() == 'torch.cuda.HalfTensor' \
                and len_key <= 2048:

            softmax_grads = mask_softmax_dropout_cuda.backward_recompute(heads_t[0], grad_outputs, softmax_results,
                                                                         dropout_mask, dropout_prob_t[0])

        else:
            dropout_grads = torch._masked_scale(grad_outputs, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))

            # be careful we overwrite into "softmax_results" memory here
            softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)

        return softmax_grads
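The else-branch above composes `torch._masked_scale` with `torch._softmax_backward_data`. Below is a minimal pure-PyTorch reference for that fallback using only public ops; the function name and the use of a boolean dropout mask are assumptions for illustration, not part of the original module:

import torch

def softmax_dropout_backward_ref(grad_outputs, softmax_results, dropout_mask, dropout_prob):
    # undo dropout: rescale the surviving positions by 1 / (1 - p)
    dropout_grads = grad_outputs * dropout_mask.to(grad_outputs.dtype) / (1.0 - dropout_prob)
    # softmax backward along the last dimension
    return softmax_results * (
        dropout_grads - (dropout_grads * softmax_results).sum(dim=-1, keepdim=True)
    )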
    def backward(ctx, output_grads, softmax_grads):
        # def backward(ctx, output_grads):
        """
        :param ctx:
        :param output_grads: gradients w.r.t the outputs
        :param softmax_grads: unnecessary unless the attention weights are used somewhere
        :return:
        """
        heads_t, \
            scale_t, \
            matmul2_results, \
            dropout_results, \
            softmax_results, \
            qkv, qkv_mm, qkv_r, \
            rpos_r, rpos_mm, \
            rw_head_q, rr_head_q, \
            inputs, pos, r_head_k, \
            input_weights, pos_weights, output_weights, \
            r_i, s_i, r_p, s_p, \
            r_w_bias, r_r_bias, \
            dropout_mask, \
            dropout_prob_t = ctx.saved_tensors

        head_dim = inputs.size(2) // heads_t[0]
        len_q, bsz = inputs.size(0), inputs.size(1)
        len_r = pos.size(0)

        # Slice out q,k,v from one big Input Linear output (should only impact meta data, no copies!)
        # input_lin_results: [len_q, bsz, heads(16), 3, head_dim(64)]
        # input_lin_results: [len_q, batches=bsz*heads, 3, head_dim]
        qkv = qkv.view(inputs.size(0),
                       inputs.size(1) * heads_t[0], 3, head_dim)
        queries = qkv[:, :, 0, :]
        keys = qkv[:, :, 1, :]
        values = qkv[:, :, 2, :]

        # The tensor is declared beforehand to properly slice out query, key, and value grads.
        qkv_grads = torch.empty_like(qkv)
        queries_grads = qkv_grads[:, :, 0, :]
        keys_grads = qkv_grads[:, :, 1, :]
        values_grads = qkv_grads[:, :, 2, :]

        # Output Linear Projection
        o_input = matmul2_results

        # output_lin_grads, output_weights_grads, output_biases_grads, r_o_grads, s_o_grads \
        #     = mm.backward(output_grads, o_input, o_r, o_mm, output_weights, r_o, s_o)
        output_lin_grads = torch.mm(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)), output_weights)
        output_lin_grads = output_lin_grads.view(output_grads.size(0),
                                                 output_grads.size(1),
                                                 output_weights.size(1))

        output_weights_grads = torch.mm(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)).transpose(0, 1),
            matmul2_results.view(
                matmul2_results.size(0) * matmul2_results.size(1),
                matmul2_results.size(2)))
        output_lin_grads = output_lin_grads.view(inputs.size(0),
                                                 inputs.size(1) * heads_t[0],
                                                 head_dim).transpose(0, 1)

        output_biases_grads = torch.sum(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)), 0)

        # Matmul2 - DGRAD1
        # Input1: (data grads)  [len_q, bsz*heads, head_dim] transpose(0,1)
        # Input2: (activations) [seql_k, bsz*heads, head_dim] transpose(0,1).transpose(1,2)
        # Output:               [bsz*heads, len_q, seql_k]
        # GEMM: Per batch: ( len_q x head_dim ) x ( head_dim x seql_k ) = ( len_q x seql_k )
        matmul2_dgrad1 = torch.bmm(output_lin_grads,
                                   values.transpose(0, 1).transpose(1, 2))
        # Matmul2 - DGRAD2
        # Input1: (activations) [bsz*heads, len_q, seql_k] transpose(1,2)
        # Input2: (data grads)  [bsz*heads, len_q, head_dim]
        # Output: values_grads  [bsz*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x len_q ) x ( len_q x head_dim ) = ( seql_k x head_dim )
        torch.bmm(dropout_results.transpose(1, 2),
                  output_lin_grads,
                  out=values_grads.transpose(0, 1))

        # print("Reached here")

        # Mask and Scaling for Dropout (not a publicly documented op)
        if dropout_prob_t[0] > 0.0:
            dropout_grads = torch._masked_scale(
                matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
        else:
            dropout_grads = matmul2_dgrad1

        # Softmax Grad (not a publicly documented op)
        softmax_grads = torch._softmax_backward_data(dropout_grads,
                                                     softmax_results, -1,
                                                     softmax_results)
        attn_score_grads = softmax_grads
        # the same attention-score gradient flows into both the AC and BD terms
        matmul_ac_grads = attn_score_grads

        # Matmul1 - DGRAD1
        # Input1: (data grads)  [bsz*heads, len_q, seql_k]
        # Input2: (activations) [seql_k, bsz*heads, head_dim] transpose(0,1)
        # Output:               [bsz*heads, len_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = ( len_q x head_dim )
        torch.baddbmm(queries_grads.transpose(0, 1),
                      matmul_ac_grads,
                      keys.transpose(0, 1),
                      out=queries_grads.transpose(0, 1),
                      beta=0.0,
                      alpha=scale_t[0])

        queries_grads_ac = queries_grads
        r_w_bias_grads = torch.sum(queries_grads_ac.view(
            len_q, bsz, heads_t[0], -1),
                                   dim=[0, 1])  # heads * head_dim

        matmul_bd_grads = attn_score_grads

        if len_r > len_q:  # if the BD terms were truncated in the forward, append zero gradients for the removed positions
            grad_cut = matmul_bd_grads.new_zeros(
                (matmul_bd_grads.size(0), matmul_bd_grads.size(1),
                 len_r - len_q))
            matmul_bd_grads = torch.cat([matmul_bd_grads, grad_cut], dim=-1)

        # backprop through the shifting
        matmul_bd_grads = RelativeShift.backward(matmul_bd_grads, True, False)

        # MatmulBD - DGRAD1
        # Input1: (matmul_bd_grads)  [bsz*heads, len_q, seql_k]
        # Input2: (r_head_k) [len_q, bsz*heads, head_dim] transpose(0,1)
        # Output:               [bsz*heads, len_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = ( len_q x head_dim )
        queries_grads_bd = queries_grads.new_empty(*queries_grads.size())
        torch.baddbmm(queries_grads_bd.transpose(0, 1),
                      matmul_bd_grads,
                      r_head_k.transpose(0, 1),
                      out=queries_grads_bd.transpose(0, 1),
                      beta=0.0,
                      alpha=scale_t[0])

        # len_q x batch*heads x d_head
        r_r_bias_grads = torch.sum(queries_grads_bd.view(
            len_q, bsz, heads_t[0], -1),
                                   dim=[0, 1])

        # add the gradients from bd to queries
        queries_grads.add_(queries_grads_bd)

        # MatmulAC - DGRAD2
        # Input1: (data grads)  [bsz*heads, len_q, seql_k] transpose(1,2)
        # Input2: (rw_head_q) [bsz*heads, head_dim, len_q] transpose(0,1)
        # Output:               [seql_k, bsz*heads, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x len_q ) x ( len_q x head_dim ) = ( seql_k x head_dim )
        torch.baddbmm(keys_grads.transpose(0, 1),
                      matmul_ac_grads.transpose(1, 2),
                      rw_head_q.transpose(0, 1),
                      out=keys_grads.transpose(0, 1),
                      beta=0.0,
                      alpha=scale_t[0])

        # MatmulBD - DGRAD2
        # Input1: (data grads)  [bsz*heads, len_q, len_r] transpose(1,2)
        # Input2: (rr_head_q) [len_q, bsz*heads, head_dim] transpose(0,1)
        # Output:  r_head_k  [len_r, bsz*heads, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x len_q ) x ( len_q x head_dim ) = ( seql_k x head_dim )
        r_head_k_grad = r_head_k.new_empty((len_r, bsz * heads_t[0], head_dim))
        # rr_head_q = queries.view(len_q, bsz, heads_t[0], head_dim) + r_r_bias  #
        # rr_head_q = rr_head_q.view(len_q, bsz * heads_t[0], head_dim)
        torch.baddbmm(r_head_k_grad.transpose(0, 1),
                      matmul_bd_grads.transpose(1, 2).contiguous(),
                      rr_head_q.transpose(0, 1),
                      out=r_head_k_grad.transpose(0, 1),
                      beta=0.0,
                      alpha=scale_t[0])
        # r_head_k_grad = torch.matmul(matmul_bd_grads.transpose(1, 2), rr_head_q.transpose(0, 1))

        r_head_k_grad = r_head_k_grad.view(len_r, bsz, heads_t[0] * head_dim)
        # Input Linear GEMM - DGRAD
        # input1: (data grads) [len_q, bsz, 3*embed_dim(3072)]
        # input2: (weights)    [embed_dim*3 (3072), embed_dim (1024)]
        # output:              [len_q, bsz, embed_dim]
        # GEMM: ( (len_q*bsz) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (len_q*bsz x embed_dim)
        qkv_grads = qkv_grads.view(inputs.size(0), inputs.size(1),
                                   heads_t[0] * 3 * head_dim)

        input_grads, input_weights_grads, input_biases_grads, r_i_grads, s_i_grads = \
            mm.backward(qkv_grads, inputs, qkv_r, qkv_mm, input_weights, r_i, s_i)

        _, pos_weights_grads, pos_biases_grads, r_p_grads, s_p_grads = \
            mm.backward(r_head_k_grad, pos, rpos_r, rpos_mm, pos_weights, r_p, s_p, need_grad_x=False)

        return input_grads, None, None, None, None, None, \
               input_weights_grads, output_weights_grads, pos_weights_grads, \
               input_biases_grads, output_biases_grads, pos_biases_grads, \
               r_i_grads, s_i_grads, r_p_grads, s_p_grads, \
               r_w_bias_grads, r_r_bias_grads, \
               None, None, None, None, None
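The `r_w_bias_grads` and `r_r_bias_grads` reductions above rely on the fact that a bias broadcast over (len_q, bsz) in the forward pass receives the sum of the incoming gradients over those dimensions. A small sketch with assumed toy shapes (not from the example) checking that against autograd:

import torch

len_q, bsz, heads, head_dim = 4, 2, 3, 5
queries = torch.randn(len_q, bsz, heads, head_dim)
r_w_bias = torch.randn(heads, head_dim, requires_grad=True)
rw_head_q = queries + r_w_bias            # bias broadcast over (len_q, bsz)
g = torch.randn_like(rw_head_q)

manual = g.sum(dim=[0, 1])                # heads x head_dim, as in the code above
(auto,) = torch.autograd.grad(rw_head_q, r_w_bias, g)
assert torch.allclose(manual, auto, atol=1e-6)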
    def backward(ctx, output_grads, softmax_grads):

        heads_t, scale_t, matmul2_results, dropout_results, softmax_results \
            , input_lin_q_results, input_lin_kv_results \
            , inputs_q, inputs_kv \
            , input_weights_q, input_weights_kv, output_weights \
            , dropout_mask, dropout_prob_t \
            = ctx.saved_tensors

        head_dim = inputs_q.size(2) // heads_t[0]

        # Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)
        # Batch sizes and heads are combined to make the batch of the Batched GEMM
        # input_lin_kv_results: [seql_k, bsz, heads(16), 2, head_dim(64)]
        # input_lin_kv_results: [seql_k, batches=bsz*heads, 2, head_dim]
        queries = input_lin_q_results.view(inputs_q.size(0),
                                           inputs_q.size(1) * heads_t[0],
                                           head_dim)
        input_lin_kv_results = input_lin_kv_results.view(
            inputs_kv.size(0),
            inputs_kv.size(1) * heads_t[0], 2, head_dim)
        keys = input_lin_kv_results[:, :, 0, :]
        values = input_lin_kv_results[:, :, 1, :]

        # Slice out k,v from one big set of gradients entering the input linear's bprop
        # (should only impact meta data, no copies!)
        # The gradients are identical in size to the Input Linear outputs.
        # The tensor is declared beforehand to properly slice out query, key, and value grads.
        input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)
        queries_grads = torch.empty_like(queries)
        keys_grads = input_lin_kv_results_grads[:, :, 0, :]
        values_grads = input_lin_kv_results_grads[:, :, 1, :]

        # Output Linear GEMM - DGRAD
        # Input1: (data grads)  [seql_q, bsz, embed_dim=heads*head_dim]
        # Input2: (weights)     [ embed_dim, embed_dim ]
        # Output:               [ seql_q, seqs, embed_dim ]
        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
        output_lin_grads = torch.mm(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)), output_weights)
        output_lin_grads = output_lin_grads.view(output_grads.size(0),
                                                 output_grads.size(1),
                                                 output_weights.size(1))
        # Output Linear GEMM - WGRAD
        # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
        # Input2: (activations) [seql_q*seqs, embed_dim ]
        # Output:               [ seql_q, seqs, embed_dim ]
        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
        output_weight_grads = torch.mm(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)).transpose(0, 1),
            matmul2_results.view(
                matmul2_results.size(0) * matmul2_results.size(1),
                matmul2_results.size(2)))
        output_lin_grads = output_lin_grads.view(
            output_grads.size(0),
            output_grads.size(1) * heads_t[0], head_dim).transpose(0, 1)

        # Matmul2 - DGRAD1
        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
        # Output:               [seqs*heads, seql_q, seql_k]
        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
        matmul2_dgrad1 = torch.bmm(output_lin_grads,
                                   values.transpose(0, 1).transpose(1, 2))
        # Matmul2 - DGRAD2
        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (data grads)  [seqs*heads, seql_q, head_dim]
        # Output: values_grads  [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
        values_grads = torch.bmm(dropout_results.transpose(1, 2),
                                 output_lin_grads,
                                 out=values_grads.transpose(0, 1))

        # Mask and Scaling for Dropout (not a publicly documented op)
        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask,
                                            1.0 / (1.0 - dropout_prob_t[0]))

        # Softmax Grad (not a publicly documented op)
        softmax_grads = torch._softmax_backward_data(dropout_grads,
                                                     softmax_results, -1,
                                                     softmax_results)

        # Matmul1 - DGRAD1
        # Input1: (data grads)  [seqs*heads, seql_q, seql_k]
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
        queries_grads = torch.baddbmm(queries_grads.transpose(0, 1),
                                      softmax_grads,
                                      keys.transpose(0, 1),
                                      out=queries_grads.transpose(0, 1),
                                      beta=0.0,
                                      alpha=scale_t[0])
        # Matmul1 - DGRAD2
        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
        torch.baddbmm(keys_grads.transpose(0, 1),
                      softmax_grads.transpose(1, 2),
                      queries.transpose(0, 1),
                      out=keys_grads.transpose(0, 1),
                      beta=0.0,
                      alpha=scale_t[0])

        # Input Q Linear GEMM - DGRAD
        # input1: (data grads) [seql_q, seqs, embed_dim(1024)]
        # input2: (weights)    [embed_dim (1024), embed_dim (1024)]
        # output:              [seql_q, seqs, embed_dim]
        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
        queries_grads = queries_grads.transpose(0, 1).view(
            inputs_q.size(0) * inputs_q.size(1), heads_t[0] * head_dim)
        input_q_grads = torch.mm(queries_grads, input_weights_q)
        input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1),
                                           inputs_q.size(2))
        # Input KV Linear GEMM - DGRAD
        # input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]
        # input2: (weights)    [embed_dim*2 (2048), embed_dim (1024)]
        # output:              [seql_k, seqs, embed_dim]
        # GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = (seql_k*seqs x embed_dim)
        # the elements of values and query grads are already stored in (shared) query_grads and values_grads
        input_lin_kv_results_grads = input_lin_kv_results_grads.view(
            inputs_kv.size(0) * inputs_kv.size(1), heads_t[0] * 2 * head_dim)
        input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)
        input_kv_grads = input_kv_grads.view(inputs_kv.size(0),
                                             inputs_kv.size(1),
                                             inputs_kv.size(2))
        # Input Q Linear GEMM - WGRAD
        # input1: (data grads)  [seql_q*seqs, embed_dim(1024)]
        # input2: (activations) [seql_q*seqs, embed_dim(1024)]
        # output:               [embed_dim, embed_dim]
        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim)
        input_weight_q_grads = torch.mm(
            queries_grads.transpose(0, 1),
            inputs_q.view(
                inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)))
        # Input KV Linear GEMM - WGRAD
        # input1: (data grads)  [seql_k*seqs, 2*embed_dim(2048)]
        # input2: (activations) [seql_k*seqs, embed_dim(1024)]
        # output:               [2*embed_dim, embed_dim]
        # GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)
        input_weight_kv_grads = torch.mm(
            input_lin_kv_results_grads.transpose(0, 1),
            inputs_kv.view(
                inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)))

        return None, None, None \
            , input_q_grads, input_kv_grads \
            , input_weight_q_grads, input_weight_kv_grads, output_weight_grads \
            , None, None, None, None
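The Matmul1 DGRAD1/DGRAD2 steps above are the batched-GEMM form of differentiating attn = scale * Q @ K^T per batch. A small sketch with assumed shapes (not from the example) comparing that pattern to autograd:

import torch

batches, len_q, len_k, head_dim, scale = 6, 4, 5, 8, 0.125
q = torch.randn(batches, len_q, head_dim, requires_grad=True)
k = torch.randn(batches, len_k, head_dim, requires_grad=True)
attn = torch.bmm(q, k.transpose(1, 2)) * scale
g = torch.randn_like(attn)

q_grads = torch.bmm(g, k) * scale                   # ( len_q x len_k ) x ( len_k x head_dim )
k_grads = torch.bmm(g.transpose(1, 2), q) * scale   # ( len_k x len_q ) x ( len_q x head_dim )

dq, dk = torch.autograd.grad(attn, (q, k), g)
assert torch.allclose(q_grads, dq, atol=1e-5) and torch.allclose(k_grads, dk, atol=1e-5)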
Example #6
    def backward(ctx, *output_grads):
        """
        :param ctx:
        :param output_grads: gradients w.r.t the outputs
        :return:
        """

        if not ctx.recompute:

            heads_t, \
            scale_t, \
            matmul2_results, \
            dropout_results, \
            softmax_results, \
            input_lin_results, pos_lin_results, \
            rw_head_q, rr_head_q, \
            inputs, pos, r_head_k, \
            input_weights, pos_weights, \
            output_weights, \
            dropout_mask, nan_mask, \
            dropout_prob_t = ctx.saved_tensors

        else:
            heads_t, \
            scale_t, \
            inputs, pos, r_head_k, \
            input_weights, pos_weights, output_weights, \
            input_biases, pos_biases, output_biases, \
            r_w_bias, r_r_bias, \
            dropout_mask, nan_mask, pad_mask, \
            dropout_prob_t = ctx.saved_tensors

        learnable_pos = ctx.learnable_pos
        if ctx.return_coverage:
            output_grads, softmax_grads = output_grads
        else:
            output_grads = output_grads[0]

        head_dim = inputs.size(2) // heads_t[0]
        len_q, bsz = inputs.size(0), inputs.size(1)
        len_k = len_q

        len_r = pos.size(0)

        if ctx.fused_all:  # only applicable for learnable position and len_k <= 2048

            if ctx.recompute:
                input_grads, pos_grads, \
                input_weight_grads, \
                input_bias_grads, \
                output_weight_grads, \
                output_bias_grads, \
                r_w_bias_grads, r_r_bias_grads = rel_self_attn_cuda.backward_recompute(
                    heads_t[0], output_grads,
                    inputs, pos,
                    input_weights, output_weights,
                    input_biases, output_biases,
                    r_w_bias, r_r_bias,
                    dropout_mask, pad_mask, dropout_prob_t[0])
            else:
                input_grads, pos_grads, \
                input_weight_grads, \
                input_bias_grads, \
                output_weight_grads, \
                output_bias_grads, \
                r_w_bias_grads, r_r_bias_grads = rel_self_attn_cuda.backward(
                    heads_t[0], output_grads, matmul2_results,
                    dropout_results, softmax_results,
                    input_lin_results, rw_head_q, rr_head_q,
                    inputs, pos, input_weights, output_weights,
                    dropout_mask, dropout_prob_t[0])
            pos_weight_grads = None
            pos_bias_grads = None

            return input_grads, pos_grads, None, None, None, input_weight_grads, \
                   output_weight_grads, pos_weight_grads, \
                   input_bias_grads, output_bias_grads, pos_bias_grads, r_w_bias_grads, r_r_bias_grads, \
                   None, None, None, None, None, None, None, None

        if ctx.recompute:
            heads = heads_t[0]

            # Recomputing the activations in the forward pass here
            input_lin_results = torch.addmm(
                input_biases,
                inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
                input_weights.transpose(0, 1),
                beta=1.,
                alpha=1.)

            input_lin_results = input_lin_results.view(inputs.size(0),
                                                       inputs.size(1),
                                                       input_weights.size(0))
            if not learnable_pos:
                pos_lin_results = torch.addmm(pos_biases,
                                              pos.view(
                                                  pos.size(0) * pos.size(1),
                                                  pos.size(2)),
                                              pos_weights.transpose(0, 1),
                                              beta=1.,
                                              alpha=1.)

                pos_lin_results = pos_lin_results.view(pos.size(0),
                                                       pos.size(1),
                                                       pos_weights.size(0))

                r_head_k = pos_lin_results.view(pos.size(0), bsz * heads,
                                                head_dim)  # T x BxH x D
            else:
                # pos_lin_results = pos.view(pos.size(0), bsz * heads, head_dim)  # T x BxH x D
                # r_head_k = pos_lin_results
                pos_lin_results = None
                r_head_k = None

            input_lin_results = input_lin_results.view(inputs.size(0),
                                                       inputs.size(1) * heads,
                                                       3, head_dim)
            queries = input_lin_results[:, :, 0, :]
            keys = input_lin_results[:, :, 1, :]
            values = input_lin_results[:, :, 2, :]

            rw_head_q = queries.view(len_q, bsz, heads, head_dim) + r_w_bias  #
            rw_head_q = rw_head_q.view(len_q, bsz * heads, head_dim)

            matmul_ac = torch.empty(
                (bsz * heads, queries.size(0), keys.size(0)),
                dtype=queries.dtype,
                device=rw_head_q.device)
            matmul_ac.baddbmm_(rw_head_q.transpose(0, 1),
                               keys.transpose(0, 1).transpose(1, 2),
                               beta=0.0,
                               alpha=scale_t[0])

            rr_head_q = queries.view(len_q, bsz, heads, head_dim) + r_r_bias
            rr_head_q = rr_head_q.view(len_q, bsz * heads, head_dim)

            if not learnable_pos:
                matmul_bd = torch.empty((bsz * heads, queries.size(0), len_r),
                                        dtype=queries.dtype,
                                        device=rw_head_q.device)

                matmul_bd.baddbmm_(rr_head_q.transpose(0, 1),
                                   r_head_k.transpose(0, 1).transpose(1, 2),
                                   beta=0.0,
                                   alpha=scale_t[0])

                matmul_bd = RelativeShift.forward(matmul_bd, True, False)
                matmul_bd = matmul_bd[:, :, :len_k]
                attn_score = matmul_ac + matmul_bd

            else:
                matmul_ac.transpose(0, 1).baddbmm_(rr_head_q,
                                                   pos.transpose(1, 2),
                                                   beta=1.0,
                                                   alpha=scale_t[0])
                attn_score = matmul_ac

            if pad_mask is not None:
                attn_score.view(bsz, heads, len_q,
                                len_k).masked_fill_(pad_mask, float('-inf'))

            softmax_results = F.softmax(attn_score, dim=-1).type_as(attn_score)
            del attn_score

            pinv = 1.0 / (1.0 - dropout_prob_t[0])
            dropout_results = softmax_results * dropout_mask * pinv
            matmul2_results = torch.bmm(dropout_results,
                                        values.transpose(0,
                                                         1)).transpose(0, 1)
            matmul2_results = matmul2_results.contiguous().view(
                inputs.size(0), inputs.size(1), inputs.size(2))

        # Slice out q,k,v from one big Input Linear output (should only impact meta data, no copies!)
        # input_lin_results: [len_q, bsz, heads(16), 3, head_dim(64)]
        # input_lin_results: [len_q, batches=bsz*heads, 3, head_dim]
        input_lin_results = input_lin_results.view(inputs.size(0),
                                                   inputs.size(1) * heads_t[0],
                                                   3, head_dim)
        queries = input_lin_results[:, :, 0, :]
        keys = input_lin_results[:, :, 1, :]
        values = input_lin_results[:, :, 2, :]

        # The tensor is declared beforehand to properly slice out query, key, and value grads.
        input_lin_results_grads = torch.empty_like(input_lin_results)
        queries_grads = input_lin_results_grads[:, :, 0, :]
        keys_grads = input_lin_results_grads[:, :, 1, :]
        values_grads = input_lin_results_grads[:, :, 2, :]

        # Output Linear GEMM - DGRAD
        # Input1: (data grads)  [len_q, bsz, embed_dim=heads*head_dim]
        # Input2: (weights)     [ embed_dim, embed_dim ]
        # Output:               [ len_q, bsz, embed_dim ]
        # GEMM: ( len_q*bsz x embed_dim ) x ( embed_dim x embed_dim ) = ( len_q*bsz x embed_dim )
        output_lin_grads = torch.mm(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)), output_weights)
        output_lin_grads = output_lin_grads.view(output_grads.size(0),
                                                 output_grads.size(1),
                                                 output_weights.size(1))
        # Output Linear GEMM - WGRAD
        # Input1: (data grads)  [len_q*bsz, embed_dim=heads*head_dim] transpose(0,1)
        # Input2: (activations) [len_q*bsz, embed_dim ]
        # Output:               [ len_q, bsz, embed_dim ]
        # GEMM: ( embed_dim x len_q*bsz ) x ( len_q*bsz x embed_dim ) = ( embed_dim x embed_dim )
        output_weight_grads = torch.mm(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)).transpose(0, 1),
            matmul2_results.view(
                matmul2_results.size(0) * matmul2_results.size(1),
                matmul2_results.size(2)))
        output_lin_grads = output_lin_grads.view(inputs.size(0),
                                                 inputs.size(1) * heads_t[0],
                                                 head_dim).transpose(0, 1)

        output_bias_grads = torch.sum(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)), 0)

        # Matmul2 - DGRAD1
        # Input1: (data grads)  [bsz*heads, len_q,  head_dim]
        # Input2: (activations) [seql_k, bsz*heads, head_dim] transpose(0,1).transpose(1,2)
        # Output:               [bsz*heads, len_q, seql_k]
        # GEMM: Per batch: ( len_q x head_dim ) x ( head_dim x seql_k ) = ( len_q x seql_k )
        matmul2_dgrad1 = torch.bmm(output_lin_grads,
                                   values.transpose(0, 1).transpose(1, 2))
        # Matmul2 - DGRAD2
        # Input2: (data grads)  [bsz*heads, len_q,  head_dim]
        # Input1: (activations) [bsz*heads, len_q, len_k] transpose(1,2)
        # Output:               [bsz*heads, len_k, head_dim]
        # GEMM: Per batch: ( len_k x len_q ) x ( len_q x head_dim ) = ( len_k x head_dim )
        torch.bmm(dropout_results.transpose(1, 2),
                  output_lin_grads,
                  out=values_grads.transpose(0, 1))

        # Input1: (data grads)  [bsz*heads, len_q,  head_dim].transpose(0, 1)
        # Input2: (rpositions) [len_q, len_k, head_dim].transpose(1,2)
        # Output:               [bsz*heads, len_q, seql_k].transpose(0, 1)
        # torch.baddbmm(matmul2_dgrad1.transpose(0, 1), output_lin_grads.transpose(0, 1), pos.transpose(1, 2),
        #               beta=1.0, alpha=1.0, out=matmul2_dgrad1.transpose(0, 1))
        # Input2: (data grads)  [bsz*heads, len_q,  head_dim].transpose(0, 1)
        # Input1: (activations) [bsz*heads, len_q, len_k] transpose(0,1).transpose(1,2)
        # Output:               [len_q, len_k, head_dim]
        # pos_grads = torch.bmm(dropout_results.transpose(0, 1).transpose(1, 2), output_lin_grads.transpose(0, 1))

        if mask_softmax_dropout_cuda is not None and len_k <= 2048 \
                and matmul2_dgrad1.type() == 'torch.cuda.HalfTensor' and not ctx.double_precision:

            # verification code
            # matmul2_dgrad1_ref = matmul2_dgrad1.clone()
            # softmax_results_ref = softmax_results.clone()
            # dropout_grads = torch._masked_scale(matmul2_dgrad1_ref, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
            # softmax_grads_ref = torch._softmax_backward_data(dropout_grads,
            #                                                  softmax_results_ref, -1, softmax_results_ref)

            # softmax_grads_ref = mask_softmax_dropout_cuda.backward(heads_t[0], matmul2_dgrad1_ref, softmax_results,
            #                                                              dropout_mask, dropout_prob_t[0])

            # note: this function doesn't really recompute anything; if we don't store the
            # softmax results we have to store matmul2_dgrad1 instead, so we only allocate
            # more memory here while the memory kept for backward stays the same
            softmax_grads = mask_softmax_dropout_cuda.backward_recompute(
                heads_t[0], matmul2_dgrad1, softmax_results, dropout_mask,
                dropout_prob_t[0])

            # verification code (next)
            # comp = torch.allclose(softmax_grads_ref, softmax_grads, rtol=1e-03, atol=1e-04)
            # if not comp:
            #     print("ERROR: Gradients mismatched.")
            #     print(softmax_grads_ref - softmax_grads)
            # else:
            #     print("Gradients matched.")

        else:
            # Mask and Scaling for Dropout (not a publicly documented op)
            if dropout_prob_t[0] > 0.0:
                dropout_grads = torch._masked_scale(
                    matmul2_dgrad1, dropout_mask,
                    1.0 / (1.0 - dropout_prob_t[0]))
            else:
                dropout_grads = matmul2_dgrad1

            # Softmax Grad (not a publicly documented op)
            softmax_grads = torch._softmax_backward_data(
                dropout_grads, softmax_results, -1, softmax_results)

        attn_score_grads = softmax_grads
        # the same attention-score gradient flows into both the AC and BD terms
        matmul_ac_grads = attn_score_grads

        # Matmul1 - DGRAD1
        # Input1: (data grads)  [bsz*heads, len_q, seql_k]
        # Input2: (activations) [seql_k, bsz*heads, head_dim] transpose(0,1)
        # Output:               [bsz*heads, len_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = ( len_q x head_dim )
        torch.baddbmm(queries_grads.transpose(0, 1),
                      matmul_ac_grads,
                      keys.transpose(0, 1),
                      out=queries_grads.transpose(0, 1),
                      beta=0.0,
                      alpha=scale_t[0])

        queries_grads_ac = queries_grads
        r_w_bias_grads = torch.sum(queries_grads_ac.view(
            len_q, bsz, heads_t[0], -1),
                                   dim=[0, 1])  # heads * head_dim

        matmul_bd_grads = attn_score_grads

        if not learnable_pos:

            if len_r > len_q:  # if the BD terms were truncated in the forward, append zero gradients for the removed positions
                grad_cut = matmul_bd_grads.new_zeros(
                    (matmul_bd_grads.size(0), matmul_bd_grads.size(1),
                     len_r - len_q))
                matmul_bd_grads = torch.cat([matmul_bd_grads, grad_cut],
                                            dim=-1)

            # backprop through the shifting
            matmul_bd_grads = RelativeShift.backward(matmul_bd_grads, True,
                                                     False)

            # MatmulBD - DGRAD1
            # Input1: (matmul_bd_grads)  [bsz*heads, len_q, seql_k]
            # Input2: (r_head_k) [len_q, bsz*heads, head_dim] transpose(0,1)
            # Output:               [bsz*heads, len_q, head_dim] transpose(0,1)
            # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = ( len_q x head_dim )
            queries_grads_bd = queries_grads.new_empty(*queries_grads.size())
            torch.baddbmm(queries_grads_bd.transpose(0, 1),
                          matmul_bd_grads,
                          r_head_k.transpose(0, 1),
                          out=queries_grads_bd.transpose(0, 1),
                          beta=0.0,
                          alpha=scale_t[0])
        else:
            # MatmulBD - DGRAD1
            # Input1: (matmul_bd_grads)  [bsz*heads, len_q, len_k] transpose(0,1)
            # Input2: (pos) [len_q, len_k, head_dim]
            # Output:               [len_q, bsz*heads, head_dim]
            # GEMM: Per batch: ( bsz*heads x len_k ) x ( len_k x head_dim ) = ( bsz*heads x head_dim )
            queries_grads_bd = queries_grads.new_empty(*queries_grads.size())
            torch.baddbmm(queries_grads_bd,
                          matmul_bd_grads.transpose(0, 1),
                          pos,
                          out=queries_grads_bd,
                          beta=0.0,
                          alpha=scale_t[0])
            # queries_grads_bd = torch.bmm(matmul_bd_grads.transpose(0, 1), pos).mul_(scale_t[0])

        # len_q x batch*heads x d_head
        r_r_bias_grads = torch.sum(queries_grads_bd.view(
            len_q, bsz, heads_t[0], -1),
                                   dim=[0, 1])
        # add the gradients from bd to queries
        queries_grads.add_(queries_grads_bd)

        # MatmulAC - DGRAD2
        # Input1: (data grads)  [bsz*heads, len_q, seql_k] transpose(1,2)
        # Input2: (rw_head_q) [bsz*heads, head_dim, len_q] transpose(0,1)
        # Output:               [seql_k, bsz*heads, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x len_q ) x ( len_q x head_dim ) = ( seql_k x head_dim )
        torch.baddbmm(keys_grads.transpose(0, 1),
                      matmul_ac_grads.transpose(1, 2),
                      rw_head_q.transpose(0, 1),
                      out=keys_grads.transpose(0, 1),
                      beta=0.0,
                      alpha=scale_t[0])

        if not learnable_pos:
            # MatmulBD - DGRAD2
            # Input1: (data grads)  [bsz*heads, len_q, len_r] transpose(1,2)
            # Input2: (rr_head_q) [len_q, bsz*heads, head_dim] transpose(0,1)
            # Output:  r_head_k  [len_r, bsz*heads, head_dim] transpose(0,1)
            # GEMM: Per batch: ( seql_k x len_q ) x ( len_q x head_dim ) = ( seql_k x head_dim )
            r_head_k_grad = r_head_k.new_empty(
                (len_r, bsz * heads_t[0], head_dim))
            torch.baddbmm(r_head_k_grad.transpose(0, 1),
                          matmul_bd_grads.transpose(1, 2).contiguous(),
                          rr_head_q.transpose(0, 1),
                          out=r_head_k_grad.transpose(0, 1),
                          beta=0.0,
                          alpha=scale_t[0])

            r_head_k_grad = r_head_k_grad.view(len_r, bsz, heads_t[0], head_dim). \
                view(len_r * bsz, heads_t[0] * head_dim)

            pos_weight_grads = torch.mm(
                r_head_k_grad.transpose(0, 1),
                pos.view(pos.size(0) * pos.size(1), pos.size(2)))

            pos_bias_grads = torch.sum(r_head_k_grad, 0)
            pos_grads = None
        else:
            pos_weight_grads, pos_bias_grads = None, None
            pos_grads = torch.empty_like(pos)
            # MatmulBD - DGRAD2
            # Input1: (data grads)  [bsz*heads, len_q, len_k] transpose(0,1),(1,2) -> [len_q, len_k, bsz*heads]
            # Input2: (rr_head_q) [len_q, bsz*heads, head_dim]
            # Output:  pos_grads  [len_q, len_k, head_dim]
            # GEMM: Per batch: ( len_k x bsz ) x ( bsz x head_dim ) = ( len_k x head_dim )
            torch.baddbmm(pos_grads,
                          matmul_bd_grads.transpose(0, 1).transpose(
                              1, 2).contiguous(),
                          rr_head_q,
                          out=pos_grads,
                          beta=0.0,
                          alpha=scale_t[0])

        # Input Linear GEMM - DGRAD
        # input1: (data grads) [len_q, bsz, 3*embed_dim(3072)]
        # input2: (weights)    [embed_dim*3 (3072), embed_dim (1024)]
        # output:              [len_q, bsz, embed_dim]
        # GEMM: ( (len_q*bsz) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (len_q*bsz x embed_dim)
        input_lin_results_grads = input_lin_results_grads.view(
            inputs.size(0) * inputs.size(1), heads_t[0] * 3 * head_dim)
        input_grads = torch.mm(input_lin_results_grads, input_weights)
        input_grads = input_grads.view(inputs.size(0), inputs.size(1),
                                       inputs.size(2))
        # Input Linear GEMM - WGRAD
        # input1: (data grads)  [len_q*bsz, 3*embed_dim(3072)]
        # input2: (activations) [len_q*bsz, embed_dim(1024)]
        # output:               [3*embed_dim, embed_dim]
        # GEMM: ( 3*embed_dim x len_q*bsz ) x ( len_q*bsz x embed_dim ) = (3*embed_dim x embed_dim)
        input_weight_grads = torch.mm(
            input_lin_results_grads.transpose(0, 1),
            inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)))

        input_bias_grads = torch.sum(input_lin_results_grads, 0)

        return input_grads, pos_grads, None, None, None, input_weight_grads, output_weight_grads, pos_weight_grads, \
               input_bias_grads, output_bias_grads, pos_bias_grads, r_w_bias_grads, r_r_bias_grads, \
               None, None, None, None, None, None, None, None
Example #7
    def backward(ctx, output_grads):
        use_biases_t,                                                   \
        heads_t,                                                        \
        scale_t,                                                        \
        matmul2_results,                                                \
        dropout_results,                                                \
        softmax_results,                                                \
        input_lin_results,                                              \
        inputs,                                                         \
        input_weights,                                                  \
        output_weights,                                                 \
        dropout_mask,                                                   \
        dropout_prob_t          = ctx.saved_tensors

        head_dim                = inputs.size(2) // heads_t[0]

        # Slice out q,k,v from one big Input Linear output (should only impact meta data, no copies!)
        # Sequences and heads are combined to make the batch of the Batched GEMM
        # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
        # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
        input_lin_results       = input_lin_results.view(inputs.size(0), inputs.size(1)*heads_t[0], 3, head_dim)
        queries                 = input_lin_results[:,:,0,:]
        keys                    = input_lin_results[:,:,1,:]
        values                  = input_lin_results[:,:,2,:]

        # Slice out q,k,v from one big set of gradients entering the input linear's bprop  (should only impact meta data, no copies!)
        # The gradients are identical in size to the Input Linear outputs.
        # The tensor is declared beforehand to properly slice out query, key, and value grads.
        input_lin_results_grads = torch.empty_like(input_lin_results)
        queries_grads           = input_lin_results_grads[:,:,0,:]
        keys_grads              = input_lin_results_grads[:,:,1,:]
        values_grads            = input_lin_results_grads[:,:,2,:]

        # Output Linear GEMM - DGRAD
        # Input1: (data grads)  [seql_q, seqs, embed_dim=heads*head_dim]
        # Input2: (weights)     [ embed_dim, embed_dim ]
        # Output:               [ seql_q, seqs, embed_dim ]
        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
        output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)
        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
        # Output Linear GEMM - WGRAD
        # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
        # Input2: (activations) [seql_q*seqs, embed_dim ]
        # Output:               [ seql_q, seqs, embed_dim ]
        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
        output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
                                       matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
        output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1)*heads_t[0], head_dim).transpose(0,1)

        if use_biases_t[0]:
            output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
        else:
            output_bias_grads = None

        # Matmul2 - DGRAD1
        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
        # Output:               [seqs*heads, seql_q, seql_k]
        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
        matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))
        # Matmul2 - DGRAD2
        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (data grads)  [seqs*heads, seql_q, head_dim]
        # Output: values_grads  [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
        values_grads   = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))

        # Mask and Scaling for Dropout (not a publicly documented op)
        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, dropout_prob_t[0])

        # Softmax Grad (not a publicly documented op)
        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)

        # Matmul1 - DGRAD1
        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] 
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
        queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
                                      out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
        # Matmul1 - DGRAD2
        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
        keys_grads    = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
                                      out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])

        # Input Linear GEMM - DGRAD
        # input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)]
        # input2: (weights)    [embed_dim*3 (3072), embed_dim (1024)] 
        # output:              [seql_q, seqs, embed_dim]
        # GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
        input_lin_results_grads = input_lin_results_grads.view(inputs.size(0)*inputs.size(1), heads_t[0]*3*head_dim)
        input_grads = torch.mm(input_lin_results_grads, input_weights)
        input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))
        # Input Linear GEMM - WGRAD
        # input1: (data grads)  [seql_q*seqs, 3*embed_dim(3072)]
        # input2: (activations) [seql_q*seqs, embed_dim(1024)] 
        # output:               [3*embed_dim, embed_dim]
        # GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)
        input_weight_grads = torch.mm(input_lin_results_grads.transpose(0,1), inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2)))

        if use_biases_t[0]:
            input_bias_grads = torch.sum(input_lin_results_grads, 0)
        else:
            input_bias_grads = None

        return None, None, None, None,                   \
               input_grads,                              \
               input_weight_grads, output_weight_grads,  \
               input_bias_grads, output_bias_grads,      \
               None, None
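The Output/Input Linear DGRAD and WGRAD GEMMs above follow the usual rule for a linear layer y = x @ W^T: dL/dx = g @ W and dL/dW = g^T @ x with the (sequence, batch) dimensions flattened. A small sketch with assumed shapes (not from the example) checking that against autograd:

import torch

seql, bsz, embed = 5, 3, 8
x = torch.randn(seql, bsz, embed, requires_grad=True)
w = torch.randn(embed, embed, requires_grad=True)
out = torch.nn.functional.linear(x, w)       # y = x @ w.T
g = torch.randn_like(out)

g2d, x2d = g.view(-1, embed), x.view(-1, embed)
dgrad = torch.mm(g2d, w).view_as(x)          # DGRAD: dL/dx
wgrad = torch.mm(g2d.t(), x2d)               # WGRAD: dL/dw

dx, dw = torch.autograd.grad(out, (x, w), g)
assert torch.allclose(dgrad, dx, atol=1e-5) and torch.allclose(wgrad, dw, atol=1e-5)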
    def backward(ctx, *output_grads):

        incremental = ctx.incremental
        len_q = ctx.len_q
        len_key = ctx.len_k

        if ctx.return_coverage:
            output_grads, coverage_grads = output_grads
        else:
            output_grads = output_grads[0]

        heads_t, scale_t, matmul2_results, dropout_results, softmax_results, \
        input_lin_q_results, input_lin_kv_results, \
        inputs_q, inputs_kv, \
        input_weights_q, input_weights_kv, output_weights, \
        dropout_mask, dropout_prob_t \
            = ctx.saved_tensors

        head_dim = inputs_q.size(2) // heads_t[0]
        bsz = inputs_q.size(1)

        if ctx.fused_all:
            assert encdec_multihead_attn_cuda is not None and len_key <= 2048

            input_q_grads, \
            input_kv_grads, \
            input_weight_q_grads, \
            input_weight_kv_grads, \
            output_weight_grads = encdec_multihead_attn_cuda.backward(heads_t[0], output_grads, matmul2_results,
                                                                      dropout_results,
                                                                      softmax_results, input_lin_q_results,
                                                                      input_lin_kv_results,
                                                                      inputs_q, inputs_kv, input_weights_q,
                                                                      input_weights_kv,
                                                                      output_weights, dropout_mask,
                                                                      dropout_prob_t[0])

            return None, None, None, \
                   input_q_grads, input_kv_grads, \
                   input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \
                   None, None, None, None, None, None

        # Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)
        # Batch sizes and heads are combined to make the batch of the Batched GEMM
        # input_lin_kv_results: [seql_k, bsz, heads(16), 2, head_dim(64)]
        # input_lin_kv_results: [seql_k, batches=bsz*heads, 2, head_dim]
        queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads_t[0], head_dim)
        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1) * heads_t[0], 2, head_dim)
        keys = input_lin_kv_results[:, :, 0, :]
        values = input_lin_kv_results[:, :, 1, :]

        # Slice out k,v from one big set of gradients entering the input linear's bprop
        # (should only impact meta data, no copies!)
        # The gradients are identical in size to the Input Linear outputs.
        # The tensor is declared before hand to properly slice out query, key, and value grads.
        input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)
        queries_grads = torch.empty_like(queries)
        keys_grads = input_lin_kv_results_grads[:, :, 0, :]
        values_grads = input_lin_kv_results_grads[:, :, 1, :]

        # Output Linear GEMM - DGRAD
        # Input1: (data grads)  [seql_q, bsz, embed_dim=heads*head_dim]
        # Input2: (weights)     [ embed_dim, embed_dim ]
        # Output:               [ seql_q, seqs, embed_dim ]
        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
        output_lin_grads = torch.mm(
            output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)
        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
        # Output Linear GEMM - WGRAD
        # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
        # Input2: (activations) [seql_q*seqs, embed_dim ]
        # Output:               [ seql_q, seqs, embed_dim ]
        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
        output_weight_grads = torch.mm(
            output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0, 1),
            matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1) * heads_t[0],
                                                 head_dim).transpose(0, 1)

        # Matmul2 - DGRAD1
        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
        # Output:               [seqs*heads, seql_q, seql_k]
        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
        matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))
        # Matmul2 - DGRAD2
        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (data grads)  [seqs*heads, seql_q, head_dim]
        # Output: values_grads  [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
        values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads, out=values_grads.transpose(0, 1))

        if mask_softmax_dropout_cuda is not None and matmul2_dgrad1.type() == 'torch.cuda.HalfTensor' \
                and len_key <= 2048:

            # This is a safe implementation
            # softmax_grads = mask_softmax_dropout_cuda.backward(heads_t[0], matmul2_dgrad1, softmax_results,
            #                                                    dropout_mask, dropout_prob_t[0])

            # matmul2_dgrad1_ref = matmul2_dgrad1.clone()
            # softmax_results_ref = softmax_results.clone()
            # dropout_grads = torch._masked_scale(matmul2_dgrad1_ref, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
            # softmax_grads_ref = torch._softmax_backward_data(dropout_grads,
            #                                                  softmax_results_ref, -1, softmax_results_ref)

            softmax_grads = mask_softmax_dropout_cuda.backward_recompute(heads_t[0], matmul2_dgrad1, softmax_results,
                                                                         dropout_mask, dropout_prob_t[0])
            # comp = torch.allclose(softmax_grads_ref, softmax_grads, rtol=1e-03, atol=1e-04)
            # if not comp:
            #     # print(softmax_grads_ref - softmax_grads)
            #     print("ERROR: Gradients mismatched.")
            #     print(softmax_grads_ref - softmax_grads)
            # else:
            #     print("Gradients matched.")
        else:
            # Mask and Scaling for Dropout (not a publicly documented op)
            dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))

            # Softmax Grad (not a publicly documented op)
            softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)

        # Matmul1 - DGRAD1
        # Input1: (data grads)  [seqs*heads, seql_q, seql_k]
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
        queries_grads = torch.baddbmm(queries_grads.transpose(0, 1), softmax_grads, keys.transpose(0, 1),
                                      out=queries_grads.transpose(0, 1), beta=0.0, alpha=scale_t[0])
        # Matmul1 - DGRAD2
        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
        torch.baddbmm(keys_grads.transpose(0, 1), softmax_grads.transpose(1, 2), queries.transpose(0, 1),
                      out=keys_grads.transpose(0, 1), beta=0.0, alpha=scale_t[0])

        # Input Q Linear GEMM - DGRAD
        # input1: (data grads) [seql_q, seqs, embed_dim(1024)]
        # input2: (weights)    [embed_dim (1024), embed_dim (1024)]
        # output:              [seql_q, seqs, embed_dim]
        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
        queries_grads = queries_grads.transpose(0, 1).view(inputs_q.size(0) * inputs_q.size(1), heads_t[0] * head_dim)
        input_q_grads = torch.mm(queries_grads, input_weights_q)
        input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
        # Input KV Linear GEMM - DGRAD
        # input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]
        # input2: (weights)    [embed_dim*2 (2048), embed_dim (1024)]
        # output:              [seql_k, seqs, embed_dim]
        # GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = (seql_k*seqs x embed_dim)
        # the elements of values and query grads are already stored in (shared) query_grads and values_grads
        input_lin_kv_results_grads = input_lin_kv_results_grads.view(inputs_kv.size(0) * inputs_kv.size(1),
                                                                     heads_t[0] * 2 * head_dim)
        input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)
        input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2))
        # Input Q Linear GEMM - WGRAD
        # input1: (data grads)  [seql_q*seqs, embed_dim(1024)]
        # input2: (activations) [seql_q*seqs, embed_dim(1024)]
        # output:               [embed_dim, embed_dim]
        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim)
        input_weight_q_grads = torch.mm(queries_grads.transpose(0, 1),
                                        inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)))
        # Input KV Linear GEMM - WGRAD
        # input1: (data grads)  [seql_k*seqs, 2*embed_dim(2048)]
        # input2: (activations) [seql_k*seqs, embed_dim(1024)]
        # output:               [2*embed_dim, embed_dim]
        # GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)
        input_weight_kv_grads = torch.mm(input_lin_kv_results_grads.transpose(0, 1),
                                         inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)))

        return None, None, None \
            , input_q_grads, input_kv_grads \
            , input_weight_q_grads, input_weight_kv_grads, output_weight_grads \
            , None, None, None, None, None, None
    def backward(ctx, output_grads, softmax_grads):

        heads_t, scale_t, matmul2_results, dropout_results, softmax_results \
            , q, q_mm, q_r, kv, kv_mm, kv_r \
            , inputs_q, inputs_kv \
            , input_weights_q, input_biases_q, r_q, s_q \
            , input_weights_kv, input_biases_kv, r_kv, s_kv \
            , output_weights, output_biases \
            , dropout_mask, dropout_prob_t \
            = ctx.saved_tensors

        head_dim = inputs_q.size(2) // heads_t[0]

        # Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)
        # Batch sizes and heads are combined to make the batch of the Batched GEMM
        # input_lin_kv_results: [seql_k, bsz, heads(16), 2, head_dim(64)]
        # input_lin_kv_results: [seql_k, batches=bsz*heads, 2, head_dim]
        queries = q.view(inputs_q.size(0),
                         inputs_q.size(1) * heads_t[0], head_dim)
        kv = kv.view(inputs_kv.size(0),
                     inputs_kv.size(1) * heads_t[0], 2, head_dim)
        keys = kv[:, :, 0, :]
        values = kv[:, :, 1, :]

        # Slice out k,v from one big set of gradients entering the input linear's bprop
        # (should only impact meta data, no copies!)
        # The gradients are identical in size to the Input Linear outputs.
        # The tensors are declared beforehand to properly slice out the key and value grads.
        kv_grads = torch.empty_like(kv)
        queries_grads = torch.empty_like(queries)
        keys_grads = kv_grads[:, :, 0, :]
        values_grads = kv_grads[:, :, 1, :]
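        # queries_grads gets its own buffer because Q comes from a separate input linear,
        # while K and V share the fused KV projection above.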

        # Output Linear Projection
        o_input = matmul2_results

        # output_lin_grads, output_weights_grads, output_biases_grads, r_o_grads, s_o_grads \
        #     = mm.backward(output_grads, o_input, o_r, o_mm, output_weights, r_o, s_o)
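        # Output Linear GEMM - DGRAD
        # input1: (data grads) [seql_q*seqs, embed_dim]
        # input2: (weights)    [embed_dim, embed_dim]
        # output:              [seql_q*seqs, embed_dim]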
        output_lin_grads = torch.mm(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)), output_weights)
        output_lin_grads = output_lin_grads.view(output_grads.size(0),
                                                 output_grads.size(1),
                                                 output_weights.size(1))
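        # Output Linear GEMM - WGRAD
        # input1: (data grads)  [seql_q*seqs, embed_dim] transpose(0,1)
        # input2: (activations) [seql_q*seqs, embed_dim]
        # output:               [embed_dim, embed_dim]
        # The bias grad below is the sum of output_grads over all seql_q*seqs positions.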
        output_weights_grads = torch.mm(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)).transpose(0, 1),
            matmul2_results.view(
                matmul2_results.size(0) * matmul2_results.size(1),
                matmul2_results.size(2)))
        output_biases_grads = torch.sum(
            output_grads.view(
                output_grads.size(0) * output_grads.size(1),
                output_grads.size(2)), 0)
        output_lin_grads = output_lin_grads.view(
            output_grads.size(0),
            output_grads.size(1) * heads_t[0], head_dim).transpose(0, 1)
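        # Fold the heads into the batch dimension so the batched GEMMs below see
        # [seqs*heads, seql_q, head_dim].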

        # Matmul2 - DGRAD1
        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
        # Output:               [seqs*heads, seql_q, seql_k]
        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )

        matmul2_dgrad1 = torch.bmm(output_lin_grads,
                                   values.transpose(0, 1).transpose(1, 2))

        # Matmul2 - DGRAD2
        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )

        torch.bmm(dropout_results.transpose(1, 2),
                  output_lin_grads,
                  out=values_grads.transpose(0, 1))

        # Mask and Scaling for Dropout (not a publicly documented op)
        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask,
                                            1.0 / (1.0 - dropout_prob_t[0]))

        # Softmax Grad (not a publicly documented op)
        softmax_grads = torch._softmax_backward_data(dropout_grads,
                                                     softmax_results, -1,
                                                     softmax_results)

        # Matmul1 - DGRAD1
        # Input1: (data grads)  [seqs*heads, seql_q, seql_k]
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
        queries_grads = torch.baddbmm(queries_grads.transpose(0, 1),
                                      softmax_grads,
                                      keys.transpose(0, 1),
                                      out=queries_grads.transpose(0, 1),
                                      beta=0.0,
                                      alpha=scale_t[0])
        # Matmul1 - DGRAD2
        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )

        torch.baddbmm(keys_grads.transpose(0, 1),
                      softmax_grads.transpose(1, 2),
                      queries.transpose(0, 1),
                      out=keys_grads.transpose(0, 1),
                      beta=0.0,
                      alpha=scale_t[0])

        # Input Q Linear GEMM - DGRAD

        # input1: (data grads) [seql_q, seqs, embed_dim(1024)]
        # input2: (weights)    [embed_dim (1024), embed_dim (1024)]
        # output:              [seql_q, seqs, embed_dim]
        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
        queries_grads = queries_grads.transpose(0, 1).view(
            inputs_q.size(0), inputs_q.size(1), heads_t[0] * head_dim)

        # print("Reached 2 here")
        # print(queries_grads.size(), q_r.size(), q_mm.size())
        inputs_q_grads, input_weights_q_grads, input_biases_q_grads, r_q_grads, s_q_grads \
            = mm.backward(queries_grads, inputs_q, q_r, q_mm, input_weights_q, r_q, s_q)

        kv_grads = kv_grads.view(inputs_kv.size(0), inputs_kv.size(1),
                                 heads_t[0] * 2 * head_dim)
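        # kv_grads now matches the layout of the fused KV input linear output, so a single
        # mm.backward call below returns input, weight, bias and (r, s) grads for both K and V.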

        inputs_kv_grads, input_weights_kv_grads, input_biases_kv_grads, r_kv_grads, s_kv_grads \
            = mm.backward(kv_grads, inputs_kv, kv_r, kv_mm, input_weights_kv, r_kv, s_kv)

        return None, None, None, None \
            , inputs_q_grads, inputs_kv_grads \
            , input_weights_q_grads, input_weights_kv_grads, output_weights_grads \
            , input_biases_q_grads, input_biases_kv_grads, output_biases_grads \
            , r_q_grads, s_q_grads, r_kv_grads, s_kv_grads \
            , None, None, None, None, None
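Both backward variants above ultimately compute the same quantity as the unfused torch._masked_scale / torch._softmax_backward_data path. The following minimal sketch (not part of the example above; shapes and names are illustrative) cross-checks the private softmax-backward op against autograd and against the closed-form gradient, falling back to the older tensor-based signature when the installed PyTorch predates 1.11.

import torch


def softmax_backward_reference(grad_y, y, dim=-1):
    # Closed form for y = softmax(x): dx = y * (dy - sum(dy * y, dim, keepdim=True))
    return y * (grad_y - (grad_y * y).sum(dim, keepdim=True))


x = torch.randn(2, 4, 8, requires_grad=True)
y = torch.softmax(x, dim=-1)
grad_y = torch.randn_like(y)

# Gradient obtained through autograd
(dx_autograd,) = torch.autograd.grad(y, x, grad_y)

# Gradient obtained from the closed-form expression
dx_reference = softmax_backward_reference(grad_y, y.detach())

# Gradient obtained from the private op; the last argument is the input dtype on
# PyTorch >= 1.11 and a tensor on older releases, hence the TypeError fallback.
try:
    dx_internal = torch._softmax_backward_data(grad_y, y.detach(), -1, x.dtype)
except TypeError:
    dx_internal = torch._softmax_backward_data(grad_y, y.detach(), -1, y.detach())

print(torch.allclose(dx_autograd, dx_reference, atol=1e-6))
print(torch.allclose(dx_reference, dx_internal, atol=1e-6))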