def backward(grad_outputs, softmax_results, dropout_prob_t, heads_t, dropout_mask):
    len_key = softmax_results.size(-1)

    if mask_softmax_dropout_cuda is not None and grad_outputs.type() == 'torch.cuda.HalfTensor' \
            and len_key <= 2048:
        softmax_grads = mask_softmax_dropout_cuda.backward_recompute(heads_t[0], grad_outputs,
                                                                     softmax_results, dropout_mask,
                                                                     dropout_prob_t[0])
    else:
        dropout_grads = torch._masked_scale(grad_outputs, dropout_mask,
                                            1.0 / (1.0 - dropout_prob_t[0]))
        # be careful: we overwrite into the "softmax_results" memory here
        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1,
                                                     softmax_results)

    return softmax_grads
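
# For reference, the eager fallback above composes two private ATen ops
# (torch._masked_scale and torch._softmax_backward_data). A minimal sketch of the
# same math in public ops, assuming a 0/1 `dropout_mask` saved from the forward
# pass and the usual inverted-dropout scaling (hypothetical helper, for illustration):
def softmax_dropout_backward_ref(grad_outputs, softmax_results, dropout_mask, p):
    # undo dropout: zero the dropped positions, rescale survivors by 1/(1-p)
    dropout_grads = grad_outputs * dropout_mask.to(grad_outputs.dtype) / (1.0 - p)
    # softmax backward: dL/dx = y * (dL/dy - sum(dL/dy * y)), with y = softmax(x)
    inner = (dropout_grads * softmax_results).sum(dim=-1, keepdim=True)
    return softmax_results * (dropout_grads - inner)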
def backward(ctx, output_grads, softmax_grads):
    """
    :param ctx:
    :param output_grads: gradients w.r.t. the outputs
    :param softmax_grads: unnecessary unless we use the attention weights somewhere
    :return:
    """
    heads_t, \
        scale_t, \
        matmul2_results, \
        dropout_results, \
        softmax_results, \
        qkv, qkv_mm, qkv_r, \
        rpos_r, rpos_mm, \
        rw_head_q, rr_head_q, \
        inputs, pos, r_head_k, \
        input_weights, pos_weights, output_weights, \
        r_i, s_i, r_p, s_p, \
        r_w_bias, r_r_bias, \
        dropout_mask, \
        dropout_prob_t = ctx.saved_tensors

    head_dim = inputs.size(2) // heads_t[0]
    len_q, bsz = inputs.size(0), inputs.size(1)
    len_r = pos.size(0)

    # Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)
    # qkv: [len_q, bsz, heads(16), 3, head_dim(64)]
    # qkv: [len_q, batches=bsz*heads, 3, head_dim]
    qkv = qkv.view(inputs.size(0), inputs.size(1) * heads_t[0], 3, head_dim)
    queries = qkv[:, :, 0, :]
    keys = qkv[:, :, 1, :]
    values = qkv[:, :, 2, :]

    # The tensor is declared beforehand to properly slice out query, key, and value grads.
    qkv_grads = torch.empty_like(qkv)
    queries_grads = qkv_grads[:, :, 0, :]
    keys_grads = qkv_grads[:, :, 1, :]
    values_grads = qkv_grads[:, :, 2, :]

    # Output Linear Projection
    o_input = matmul2_results
    # output_lin_grads, output_weights_grads, output_biases_grads, r_o_grads, s_o_grads \
    #     = mm.backward(output_grads, o_input, o_r, o_mm, output_weights, r_o, s_o)
    output_lin_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)),
        output_weights)
    output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1),
                                             output_weights.size(1))
    output_weights_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1),
                          output_grads.size(2)).transpose(0, 1),
        matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1),
                             matmul2_results.size(2)))
    output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1) * heads_t[0],
                                             head_dim).transpose(0, 1)
    output_biases_grads = torch.sum(
        output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)

    # Matmul2 - DGRAD1
    # Input1: (data grads)  [len_q, bsz*heads, head_dim] transpose(0,1)
    # Input2: (activations) [seql_k, bsz*heads, head_dim] transpose(0,1).transpose(1,2)
    # Output:               [bsz*heads, len_q, seql_k]
    # GEMM: Per batch: ( len_q x head_dim ) x ( head_dim x seql_k ) = ( len_q x seql_k )
    matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))

    # Matmul2 - DGRAD2
    # Input1: (activations) [bsz*heads, len_q, seql_k] transpose(1,2)
    # Input2: (data grads)  [bsz*heads, len_q, head_dim]
    # Output:               [seql_k, bsz*heads, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x len_q ) x ( len_q x head_dim ) = ( seql_k x head_dim )
    torch.bmm(dropout_results.transpose(1, 2), output_lin_grads, out=values_grads.transpose(0, 1))

    # Mask and Scaling for Dropout (not a publicly documented op)
    if dropout_prob_t[0] > 0.0:
        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask,
                                            1.0 / (1.0 - dropout_prob_t[0]))
    else:
        dropout_grads = matmul2_dgrad1

    # Softmax Grad (not a publicly documented op)
    softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1,
                                                 softmax_results)

    # the grads are evenly distributed to AC and BD
    attn_score_grads = softmax_grads
    matmul_ac_grads = attn_score_grads

    # Matmul1 - DGRAD1
    # Input1: (data grads)  [bsz*heads, len_q, seql_k]
    # Input2: (activations) [seql_k, bsz*heads, head_dim] transpose(0,1)
    # Output:               [bsz*heads, len_q, head_dim] transpose(0,1)
    # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = ( len_q x head_dim )
    torch.baddbmm(queries_grads.transpose(0, 1), matmul_ac_grads, keys.transpose(0, 1),
                  out=queries_grads.transpose(0, 1), beta=0.0, alpha=scale_t[0])

    queries_grads_ac = queries_grads
    r_w_bias_grads = torch.sum(queries_grads_ac.view(len_q, bsz, heads_t[0], -1),
                               dim=[0, 1])  # heads x head_dim

    matmul_bd_grads = attn_score_grads

    if len_r > len_q:  # if we cut off the BDs before, then put the zero gradients behind
        grad_cut = matmul_bd_grads.new_zeros(
            (matmul_bd_grads.size(0), matmul_bd_grads.size(1), len_r - len_q))
        matmul_bd_grads = torch.cat([matmul_bd_grads, grad_cut], dim=-1)

    # backprop through the shifting
    matmul_bd_grads = RelativeShift.backward(matmul_bd_grads, True, False)

    # MatmulBD - DGRAD1
    # Input1: (matmul_bd_grads) [bsz*heads, len_q, seql_k]
    # Input2: (r_head_k)        [len_q, bsz*heads, head_dim] transpose(0,1)
    # Output:                   [bsz*heads, len_q, head_dim] transpose(0,1)
    # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = ( len_q x head_dim )
    queries_grads_bd = queries_grads.new_empty(*queries_grads.size())
    torch.baddbmm(queries_grads_bd.transpose(0, 1), matmul_bd_grads, r_head_k.transpose(0, 1),
                  out=queries_grads_bd.transpose(0, 1), beta=0.0, alpha=scale_t[0])

    # len_q x batch*heads x d_head
    r_r_bias_grads = torch.sum(queries_grads_bd.view(len_q, bsz, heads_t[0], -1), dim=[0, 1])

    # add the gradients from bd to queries
    queries_grads.add_(queries_grads_bd)

    # MatmulAC - DGRAD2
    # Input1: (data grads) [bsz*heads, len_q, seql_k] transpose(1,2)
    # Input2: (rw_head_q)  [len_q, bsz*heads, head_dim] transpose(0,1)
    # Output:              [seql_k, bsz*heads, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x len_q ) x ( len_q x head_dim ) = ( seql_k x head_dim )
    torch.baddbmm(keys_grads.transpose(0, 1), matmul_ac_grads.transpose(1, 2),
                  rw_head_q.transpose(0, 1), out=keys_grads.transpose(0, 1),
                  beta=0.0, alpha=scale_t[0])

    # MatmulBD - DGRAD2
    # Input1: (data grads) [bsz*heads, len_q, len_r] transpose(1,2)
    # Input2: (rr_head_q)  [len_q, bsz*heads, head_dim] transpose(0,1)
    # Output: r_head_k     [len_r, bsz*heads, head_dim] transpose(0,1)
    # GEMM: Per batch: ( len_r x len_q ) x ( len_q x head_dim ) = ( len_r x head_dim )
    r_head_k_grad = r_head_k.new_empty((len_r, bsz * heads_t[0], head_dim))
    # rr_head_q = queries.view(len_q, bsz, heads_t[0], head_dim) + r_r_bias
    # rr_head_q = rr_head_q.view(len_q, bsz * heads_t[0], head_dim)
    torch.baddbmm(r_head_k_grad.transpose(0, 1), matmul_bd_grads.transpose(1, 2).contiguous(),
                  rr_head_q.transpose(0, 1), out=r_head_k_grad.transpose(0, 1),
                  beta=0.0, alpha=scale_t[0])
    # r_head_k_grad = torch.matmul(matmul_bd_grads.transpose(1, 2), rr_head_q.transpose(0, 1))

    r_head_k_grad = r_head_k_grad.view(len_r, bsz, heads_t[0] * head_dim)

    # Input Linear GEMM - DGRAD
    # input1: (data grads) [len_q, bsz, 3*embed_dim(3072)]
    # input2: (weights)    [embed_dim*3 (3072), embed_dim (1024)]
    # output:              [len_q, bsz, embed_dim]
    # GEMM: ( (len_q*bsz) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = ( len_q*bsz x embed_dim )
    qkv_grads = qkv_grads.view(inputs.size(0), inputs.size(1), heads_t[0] * 3 * head_dim)

    input_grads, input_weights_grads, input_biases_grads, r_i_grads, s_i_grads = \
        mm.backward(qkv_grads, inputs, qkv_r, qkv_mm, input_weights, r_i, s_i)

    _, pos_weights_grads, pos_biases_grads, r_p_grads, s_p_grads = \
        mm.backward(r_head_k_grad, pos, rpos_r, rpos_mm, pos_weights, r_p, s_p,
                    need_grad_x=False)

    return input_grads, None, None, None, None, None, \
        input_weights_grads, output_weights_grads, pos_weights_grads, \
        input_biases_grads, output_biases_grads, pos_biases_grads, \
        r_i_grads, s_i_grads, r_p_grads, s_p_grads, \
        r_w_bias_grads, r_r_bias_grads, \
        None, None, None, None, None
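
# The r_w_bias / r_r_bias gradients above are plain broadcast reductions: the biases
# were added to queries.view(len_q, bsz, heads, head_dim) in the forward pass,
# broadcast over len_q and bsz, so their gradients sum over those two dimensions.
# A self-contained sketch with illustrative shapes only:
def _rel_bias_grad_sketch():
    import torch
    len_q, bsz, heads, head_dim = 5, 3, 2, 4
    queries_grads = torch.randn(len_q, bsz * heads, head_dim)
    # backward of a broadcast add reduces over the broadcast dimensions
    r_w_bias_grads = torch.sum(queries_grads.view(len_q, bsz, heads, head_dim), dim=[0, 1])
    assert r_w_bias_grads.shape == (heads, head_dim)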
def backward(ctx, output_grads, softmax_grads):
    heads_t, scale_t, matmul2_results, dropout_results, softmax_results, \
        input_lin_q_results, input_lin_kv_results, \
        inputs_q, inputs_kv, \
        input_weights_q, input_weights_kv, output_weights, \
        dropout_mask, dropout_prob_t \
        = ctx.saved_tensors

    head_dim = inputs_q.size(2) // heads_t[0]

    # Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)
    # Batch sizes and heads are combined to make the batch of the Batched GEMM
    # input_lin_kv_results: [seql_k, bsz, heads(16), 2, head_dim(64)]
    # input_lin_kv_results: [seql_k, batches=bsz*heads, 2, head_dim]
    queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads_t[0], head_dim)
    input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0),
                                                     inputs_kv.size(1) * heads_t[0], 2, head_dim)
    keys = input_lin_kv_results[:, :, 0, :]
    values = input_lin_kv_results[:, :, 1, :]

    # Slice out k,v from one big set of gradients entering the input linear's bprop
    # (should only impact metadata, no copies!)
    # The gradients are identical in size to the Input Linear outputs.
    # The tensor is declared beforehand to properly slice out query, key, and value grads.
    input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)
    queries_grads = torch.empty_like(queries)
    keys_grads = input_lin_kv_results_grads[:, :, 0, :]
    values_grads = input_lin_kv_results_grads[:, :, 1, :]

    # Output Linear GEMM - DGRAD
    # Input1: (data grads) [seql_q, bsz, embed_dim=heads*head_dim]
    # Input2: (weights)    [embed_dim, embed_dim]
    # Output:              [seql_q, seqs, embed_dim]
    # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
    output_lin_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)),
        output_weights)
    output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1),
                                             output_weights.size(1))

    # Output Linear GEMM - WGRAD
    # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
    # Input2: (activations) [seql_q*seqs, embed_dim]
    # Output:               [embed_dim, embed_dim]
    # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
    output_weight_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1),
                          output_grads.size(2)).transpose(0, 1),
        matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1),
                             matmul2_results.size(2)))
    output_lin_grads = output_lin_grads.view(output_grads.size(0),
                                             output_grads.size(1) * heads_t[0],
                                             head_dim).transpose(0, 1)

    # Matmul2 - DGRAD1
    # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
    # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
    # Output:               [seqs*heads, seql_q, seql_k]
    # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
    matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))

    # Matmul2 - DGRAD2
    # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
    # Input2: (data grads)  [seqs*heads, seql_q, head_dim]
    # Output:               [seql_k, seqs*heads, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
    values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads,
                             out=values_grads.transpose(0, 1))

    # Mask and Scaling for Dropout (not a publicly documented op)
    dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask,
                                        1.0 / (1.0 - dropout_prob_t[0]))

    # Softmax Grad (not a publicly documented op)
    softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1,
                                                 softmax_results)

    # Matmul1 - DGRAD1
    # Input1: (data grads)  [seqs*heads, seql_q, seql_k]
    # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
    # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
    queries_grads = torch.baddbmm(queries_grads.transpose(0, 1), softmax_grads,
                                  keys.transpose(0, 1), out=queries_grads.transpose(0, 1),
                                  beta=0.0, alpha=scale_t[0])

    # Matmul1 - DGRAD2
    # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)
    # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
    # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
    torch.baddbmm(keys_grads.transpose(0, 1), softmax_grads.transpose(1, 2),
                  queries.transpose(0, 1), out=keys_grads.transpose(0, 1),
                  beta=0.0, alpha=scale_t[0])

    # Input Q Linear GEMM - DGRAD
    # input1: (data grads) [seql_q, seqs, embed_dim(1024)]
    # input2: (weights)    [embed_dim (1024), embed_dim (1024)]
    # output:              [seql_q, seqs, embed_dim]
    # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
    queries_grads = queries_grads.transpose(0, 1).view(inputs_q.size(0) * inputs_q.size(1),
                                                       heads_t[0] * head_dim)
    input_q_grads = torch.mm(queries_grads, input_weights_q)
    input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))

    # Input KV Linear GEMM - DGRAD
    # input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]
    # input2: (weights)    [embed_dim*2 (2048), embed_dim (1024)]
    # output:              [seql_k, seqs, embed_dim]
    # GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = ( seql_k*seqs x embed_dim )
    # the key and value grads are already stored in the (shared) keys_grads and values_grads slices
    input_lin_kv_results_grads = input_lin_kv_results_grads.view(
        inputs_kv.size(0) * inputs_kv.size(1), heads_t[0] * 2 * head_dim)
    input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)
    input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2))

    # Input Q Linear GEMM - WGRAD
    # input1: (data grads)  [seql_q*seqs, embed_dim(1024)]
    # input2: (activations) [seql_q*seqs, embed_dim(1024)]
    # output:               [embed_dim, embed_dim]
    # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
    input_weight_q_grads = torch.mm(
        queries_grads.transpose(0, 1),
        inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)))

    # Input KV Linear GEMM - WGRAD
    # input1: (data grads)  [seql_k*seqs, 2*embed_dim(2048)]
    # input2: (activations) [seql_k*seqs, embed_dim(1024)]
    # output:               [2*embed_dim, embed_dim]
    # GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = ( 2*embed_dim x embed_dim )
    input_weight_kv_grads = torch.mm(
        input_lin_kv_results_grads.transpose(0, 1),
        inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)))

    return None, None, None, \
        input_q_grads, input_kv_grads, \
        input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \
        None, None, None, None
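
# The shape bookkeeping in the batched GEMMs is easy to cross-check with an einsum
# formulation. A self-contained sketch of the Matmul2 gradient pair (illustrative
# sizes only; b plays the role of bsz*heads):
def _matmul2_grads_sketch():
    import torch
    b, len_q, len_k, head_dim = 6, 4, 7, 8
    output_lin_grads = torch.randn(b, len_q, head_dim)  # grads w.r.t. the matmul2 output
    values = torch.randn(len_k, b, head_dim)            # forward activations
    attn_weights = torch.randn(b, len_q, len_k)         # forward dropout_results
    # DGRAD1: dL/d(attn weights) = dL/dOut @ V^T
    dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))
    assert torch.allclose(dgrad1, torch.einsum('bqd,kbd->bqk', output_lin_grads, values),
                          atol=1e-5)
    # DGRAD2: dL/dV = attn_weights^T @ dL/dOut
    dgrad2 = torch.bmm(attn_weights.transpose(1, 2), output_lin_grads)
    assert torch.allclose(dgrad2, torch.einsum('bqk,bqd->bkd', attn_weights, output_lin_grads),
                          atol=1e-5)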
def backward(ctx, *output_grads):
    """
    :param ctx:
    :param output_grads: gradients w.r.t. the outputs
    :return:
    """
    if not ctx.recompute:
        heads_t, \
            scale_t, \
            matmul2_results, \
            dropout_results, \
            softmax_results, \
            input_lin_results, pos_lin_results, \
            rw_head_q, rr_head_q, \
            inputs, pos, r_head_k, \
            input_weights, pos_weights, \
            output_weights, \
            dropout_mask, nan_mask, \
            dropout_prob_t = ctx.saved_tensors
    else:
        heads_t, \
            scale_t, \
            inputs, pos, r_head_k, \
            input_weights, pos_weights, output_weights, \
            input_biases, pos_biases, output_biases, \
            r_w_bias, r_r_bias, \
            dropout_mask, nan_mask, pad_mask, \
            dropout_prob_t = ctx.saved_tensors

    learnable_pos = ctx.learnable_pos

    if ctx.return_coverage:
        output_grads, softmax_grads = output_grads
    else:
        output_grads = output_grads[0]

    head_dim = inputs.size(2) // heads_t[0]
    len_q, bsz = inputs.size(0), inputs.size(1)
    len_k = len_q
    len_r = pos.size(0)

    if ctx.fused_all:
        # only applicable for learnable positions and len_k <= 2048
        if ctx.recompute:
            input_grads, pos_grads, \
                input_weight_grads, \
                input_bias_grads, \
                output_weight_grads, \
                output_bias_grads, \
                r_w_bias_grads, r_r_bias_grads = rel_self_attn_cuda.backward_recompute(
                    heads_t[0], output_grads, inputs, pos,
                    input_weights, output_weights, input_biases, output_biases,
                    r_w_bias, r_r_bias, dropout_mask, pad_mask, dropout_prob_t[0])
        else:
            input_grads, pos_grads, \
                input_weight_grads, \
                input_bias_grads, \
                output_weight_grads, \
                output_bias_grads, \
                r_w_bias_grads, r_r_bias_grads = rel_self_attn_cuda.backward(
                    heads_t[0], output_grads, matmul2_results, dropout_results,
                    softmax_results, input_lin_results, rw_head_q, rr_head_q,
                    inputs, pos, input_weights, output_weights,
                    dropout_mask, dropout_prob_t[0])

        pos_weight_grads = None
        pos_bias_grads = None

        return input_grads, pos_grads, None, None, None, input_weight_grads, \
            output_weight_grads, pos_weight_grads, \
            input_bias_grads, output_bias_grads, pos_bias_grads, r_w_bias_grads, r_r_bias_grads, \
            None, None, None, None, None, None, None, None

    if ctx.recompute:
        heads = heads_t[0]
        # Recompute the forward-pass activations here
        input_lin_results = torch.addmm(
            input_biases,
            inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
            input_weights.transpose(0, 1),
            beta=1., alpha=1.)
        input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1),
                                                   input_weights.size(0))

        if not learnable_pos:
            pos_lin_results = torch.addmm(
                pos_biases,
                pos.view(pos.size(0) * pos.size(1), pos.size(2)),
                pos_weights.transpose(0, 1),
                beta=1., alpha=1.)
            pos_lin_results = pos_lin_results.view(pos.size(0), pos.size(1), pos_weights.size(0))
            r_head_k = pos_lin_results.view(pos.size(0), bsz * heads, head_dim)  # T x BxH x D
        else:
            # pos_lin_results = pos.view(pos.size(0), bsz * heads, head_dim)  # T x BxH x D
            # r_head_k = pos_lin_results
            pos_lin_results = None
            r_head_k = None

        input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads,
                                                   3, head_dim)
        queries = input_lin_results[:, :, 0, :]
        keys = input_lin_results[:, :, 1, :]
        values = input_lin_results[:, :, 2, :]

        rw_head_q = queries.view(len_q, bsz, heads, head_dim) + r_w_bias
        rw_head_q = rw_head_q.view(len_q, bsz * heads, head_dim)

        matmul_ac = torch.empty((bsz * heads, queries.size(0), keys.size(0)),
                                dtype=queries.dtype, device=rw_head_q.device)
        matmul_ac.baddbmm_(rw_head_q.transpose(0, 1), keys.transpose(0, 1).transpose(1, 2),
                           beta=0.0, alpha=scale_t[0])

        rr_head_q = queries.view(len_q, bsz, heads, head_dim) + r_r_bias
        rr_head_q = rr_head_q.view(len_q, bsz * heads, head_dim)

        if not learnable_pos:
            matmul_bd = torch.empty((bsz * heads, queries.size(0), len_r),
                                    dtype=queries.dtype, device=rw_head_q.device)
            matmul_bd.baddbmm_(rr_head_q.transpose(0, 1),
                               r_head_k.transpose(0, 1).transpose(1, 2),
                               beta=0.0, alpha=scale_t[0])
            matmul_bd = RelativeShift.forward(matmul_bd, True, False)
            matmul_bd = matmul_bd[:, :, :len_k]
            attn_score = matmul_ac + matmul_bd
        else:
            matmul_ac.transpose(0, 1).baddbmm_(rr_head_q, pos.transpose(1, 2),
                                               beta=1.0, alpha=scale_t[0])
            attn_score = matmul_ac

        if pad_mask is not None:
            attn_score.view(bsz, heads, len_q, len_k).masked_fill_(pad_mask, float('-inf'))

        softmax_results = F.softmax(attn_score, dim=-1).type_as(attn_score)
        del attn_score

        pinv = 1.0 / (1.0 - dropout_prob_t[0])
        dropout_results = softmax_results * dropout_mask * pinv

        matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1)).transpose(0, 1)
        matmul2_results = matmul2_results.contiguous().view(inputs.size(0), inputs.size(1),
                                                            inputs.size(2))

    # Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)
    # input_lin_results: [len_q, bsz, heads(16), 3, head_dim(64)]
    # input_lin_results: [len_q, batches=bsz*heads, 3, head_dim]
    input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads_t[0],
                                               3, head_dim)
    queries = input_lin_results[:, :, 0, :]
    keys = input_lin_results[:, :, 1, :]
    values = input_lin_results[:, :, 2, :]

    # The tensor is declared beforehand to properly slice out query, key, and value grads.
    input_lin_results_grads = torch.empty_like(input_lin_results)
    queries_grads = input_lin_results_grads[:, :, 0, :]
    keys_grads = input_lin_results_grads[:, :, 1, :]
    values_grads = input_lin_results_grads[:, :, 2, :]

    # Output Linear GEMM - DGRAD
    # Input1: (data grads) [len_q, bsz, embed_dim=heads*head_dim]
    # Input2: (weights)    [embed_dim, embed_dim]
    # Output:              [len_q, bsz, embed_dim]
    # GEMM: ( len_q*bsz x embed_dim ) x ( embed_dim x embed_dim ) = ( len_q*bsz x embed_dim )
    output_lin_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)),
        output_weights)
    output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1),
                                             output_weights.size(1))

    # Output Linear GEMM - WGRAD
    # Input1: (data grads)  [len_q*bsz, embed_dim=heads*head_dim] transpose(0,1)
    # Input2: (activations) [len_q*bsz, embed_dim]
    # Output:               [embed_dim, embed_dim]
    # GEMM: ( embed_dim x len_q*bsz ) x ( len_q*bsz x embed_dim ) = ( embed_dim x embed_dim )
    output_weight_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1),
                          output_grads.size(2)).transpose(0, 1),
        matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1),
                             matmul2_results.size(2)))
    output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1) * heads_t[0],
                                             head_dim).transpose(0, 1)
    output_bias_grads = torch.sum(
        output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)

    # Matmul2 - DGRAD1
    # Input1: (data grads)  [bsz*heads, len_q, head_dim]
    # Input2: (activations) [seql_k, bsz*heads, head_dim] transpose(0,1).transpose(1,2)
    # Output:               [bsz*heads, len_q, seql_k]
    # GEMM: Per batch: ( len_q x head_dim ) x ( head_dim x seql_k ) = ( len_q x seql_k )
    matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))

    # Matmul2 - DGRAD2
    # Input1: (activations) [bsz*heads, len_q, len_k] transpose(1,2)
    # Input2: (data grads)  [bsz*heads, len_q, head_dim]
    # Output:               [bsz*heads, len_k, head_dim]
    # GEMM: Per batch: ( len_k x len_q ) x ( len_q x head_dim ) = ( len_k x head_dim )
    torch.bmm(dropout_results.transpose(1, 2), output_lin_grads,
              out=values_grads.transpose(0, 1))

    if mask_softmax_dropout_cuda is not None and len_k <= 2048 \
            and matmul2_dgrad1.type() == 'torch.cuda.HalfTensor' and not ctx.double_precision:
        # note: this path doesn't really recompute anything; if we don't store the
        # softmax results we have to store matmul2_dgrad1 instead, so we only allocate
        # more memory transiently, and the memory kept for backward is the same
        softmax_grads = mask_softmax_dropout_cuda.backward_recompute(
            heads_t[0], matmul2_dgrad1, softmax_results, dropout_mask, dropout_prob_t[0])
        # (optional) verification against the eager path, on cloned inputs:
        #   dropout_grads = torch._masked_scale(matmul2_dgrad1.clone(), dropout_mask,
        #                                       1.0 / (1.0 - dropout_prob_t[0]))
        #   softmax_grads_ref = torch._softmax_backward_data(dropout_grads,
        #                                                    softmax_results.clone(), -1,
        #                                                    softmax_results.clone())
        #   assert torch.allclose(softmax_grads_ref, softmax_grads, rtol=1e-03, atol=1e-04)
    else:
        # Mask and Scaling for Dropout (not a publicly documented op)
        if dropout_prob_t[0] > 0.0:
            dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask,
                                                1.0 / (1.0 - dropout_prob_t[0]))
        else:
            dropout_grads = matmul2_dgrad1
        # Softmax Grad (not a publicly documented op)
        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1,
                                                     softmax_results)

    # the grads are evenly distributed to AC and BD
    attn_score_grads = softmax_grads
    matmul_ac_grads = attn_score_grads

    # Matmul1 - DGRAD1
    # Input1: (data grads)  [bsz*heads, len_q, seql_k]
    # Input2: (activations) [seql_k, bsz*heads, head_dim] transpose(0,1)
    # Output:               [bsz*heads, len_q, head_dim] transpose(0,1)
    # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = ( len_q x head_dim )
    torch.baddbmm(queries_grads.transpose(0, 1), matmul_ac_grads, keys.transpose(0, 1),
                  out=queries_grads.transpose(0, 1), beta=0.0, alpha=scale_t[0])

    queries_grads_ac = queries_grads
    r_w_bias_grads = torch.sum(queries_grads_ac.view(len_q, bsz, heads_t[0], -1),
                               dim=[0, 1])  # heads x head_dim

    matmul_bd_grads = attn_score_grads

    if not learnable_pos:
        if len_r > len_q:  # if we cut off the BDs before, then put the zero gradients behind
            grad_cut = matmul_bd_grads.new_zeros(
                (matmul_bd_grads.size(0), matmul_bd_grads.size(1), len_r - len_q))
            matmul_bd_grads = torch.cat([matmul_bd_grads, grad_cut], dim=-1)

        # backprop through the shifting
        matmul_bd_grads = RelativeShift.backward(matmul_bd_grads, True, False)

        # MatmulBD - DGRAD1
        # Input1: (matmul_bd_grads) [bsz*heads, len_q, seql_k]
        # Input2: (r_head_k)        [len_q, bsz*heads, head_dim] transpose(0,1)
        # Output:                   [bsz*heads, len_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = ( len_q x head_dim )
        queries_grads_bd = queries_grads.new_empty(*queries_grads.size())
        torch.baddbmm(queries_grads_bd.transpose(0, 1), matmul_bd_grads,
                      r_head_k.transpose(0, 1), out=queries_grads_bd.transpose(0, 1),
                      beta=0.0, alpha=scale_t[0])
    else:
        # MatmulBD - DGRAD1
        # Input1: (matmul_bd_grads) [bsz*heads, len_q, len_k] transpose(0,1)
        # Input2: (pos)             [len_q, len_k, head_dim]
        # Output:                   [len_q, bsz*heads, head_dim]
        # GEMM: Per batch: ( bsz*heads x len_k ) x ( len_k x head_dim ) = ( bsz*heads x head_dim )
        queries_grads_bd = queries_grads.new_empty(*queries_grads.size())
        torch.baddbmm(queries_grads_bd, matmul_bd_grads.transpose(0, 1), pos,
                      out=queries_grads_bd, beta=0.0, alpha=scale_t[0])
        # queries_grads_bd = torch.bmm(matmul_bd_grads.transpose(0, 1), pos).mul_(scale_t[0])

    # len_q x batch*heads x d_head
    r_r_bias_grads = torch.sum(queries_grads_bd.view(len_q, bsz, heads_t[0], -1), dim=[0, 1])

    # add the gradients from bd to queries
    queries_grads.add_(queries_grads_bd)

    # MatmulAC - DGRAD2
    # Input1: (data grads) [bsz*heads, len_q, seql_k] transpose(1,2)
    # Input2: (rw_head_q)  [len_q, bsz*heads, head_dim] transpose(0,1)
    # Output:              [seql_k, bsz*heads, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x len_q ) x ( len_q x head_dim ) = ( seql_k x head_dim )
    torch.baddbmm(keys_grads.transpose(0, 1), matmul_ac_grads.transpose(1, 2),
                  rw_head_q.transpose(0, 1), out=keys_grads.transpose(0, 1),
                  beta=0.0, alpha=scale_t[0])

    if not learnable_pos:
        # MatmulBD - DGRAD2
        # Input1: (data grads) [bsz*heads, len_q, len_r] transpose(1,2)
        # Input2: (rr_head_q)  [len_q, bsz*heads, head_dim] transpose(0,1)
        # Output: r_head_k     [len_r, bsz*heads, head_dim] transpose(0,1)
        # GEMM: Per batch: ( len_r x len_q ) x ( len_q x head_dim ) = ( len_r x head_dim )
        r_head_k_grad = r_head_k.new_empty((len_r, bsz * heads_t[0], head_dim))
        torch.baddbmm(r_head_k_grad.transpose(0, 1),
                      matmul_bd_grads.transpose(1, 2).contiguous(),
                      rr_head_q.transpose(0, 1), out=r_head_k_grad.transpose(0, 1),
                      beta=0.0, alpha=scale_t[0])

        r_head_k_grad = r_head_k_grad.view(len_r, bsz, heads_t[0], head_dim). \
            view(len_r * bsz, heads_t[0] * head_dim)

        pos_weight_grads = torch.mm(r_head_k_grad.transpose(0, 1),
                                    pos.view(pos.size(0) * pos.size(1), pos.size(2)))
        pos_bias_grads = torch.sum(r_head_k_grad, 0)
        pos_grads = None
    else:
        pos_weight_grads, pos_bias_grads = None, None
        pos_grads = torch.empty_like(pos)

        # MatmulBD - DGRAD2
        # Input1: (data grads) [bsz*heads, len_q, len_k] transpose(0,1).transpose(1,2)
        #                      -> [len_q, len_k, bsz*heads]
        # Input2: (rr_head_q)  [len_q, bsz*heads, head_dim]
        # Output: pos_grads    [len_q, len_k, head_dim]
        # GEMM: Per batch: ( len_k x bsz*heads ) x ( bsz*heads x head_dim ) = ( len_k x head_dim )
        torch.baddbmm(pos_grads,
                      matmul_bd_grads.transpose(0, 1).transpose(1, 2).contiguous(),
                      rr_head_q, out=pos_grads, beta=0.0, alpha=scale_t[0])

    # Input Linear GEMM - DGRAD
    # input1: (data grads) [len_q, bsz, 3*embed_dim(3072)]
    # input2: (weights)    [embed_dim*3 (3072), embed_dim (1024)]
    # output:              [len_q, bsz, embed_dim]
    # GEMM: ( (len_q*bsz) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = ( len_q*bsz x embed_dim )
    input_lin_results_grads = input_lin_results_grads.view(inputs.size(0) * inputs.size(1),
                                                           heads_t[0] * 3 * head_dim)
    input_grads = torch.mm(input_lin_results_grads, input_weights)
    input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))

    # Input Linear GEMM - WGRAD
    # input1: (data grads)  [len_q*bsz, 3*embed_dim(3072)]
    # input2: (activations) [len_q*bsz, embed_dim(1024)]
    # output:               [3*embed_dim, embed_dim]
    # GEMM: ( 3*embed_dim x len_q*bsz ) x ( len_q*bsz x embed_dim ) = ( 3*embed_dim x embed_dim )
    input_weight_grads = torch.mm(input_lin_results_grads.transpose(0, 1),
                                  inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)))
    input_bias_grads = torch.sum(input_lin_results_grads, 0)

    return input_grads, pos_grads, None, None, None, input_weight_grads, output_weight_grads, \
        pos_weight_grads, \
        input_bias_grads, output_bias_grads, pos_bias_grads, r_w_bias_grads, r_r_bias_grads, \
        None, None, None, None, None, None, None, None
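
# A backward this long is easiest to trust when diffed against autograd. A hedged
# harness sketch: `attn_apply` stands in for the autograd.Function.apply that wraps
# this backward (hypothetical name). Run with dropout_prob = 0 so finite differences
# are deterministic, and in float64 for gradcheck's tolerances:
def _gradcheck_sketch(attn_apply, *tensor_args):
    import torch
    args = tuple(t.double().detach().requires_grad_(True) for t in tensor_args)
    return torch.autograd.gradcheck(attn_apply, args, eps=1e-6, atol=1e-4)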
def backward(ctx, output_grads):
    use_biases_t, \
        heads_t, \
        scale_t, \
        matmul2_results, \
        dropout_results, \
        softmax_results, \
        input_lin_results, \
        inputs, \
        input_weights, \
        output_weights, \
        dropout_mask, \
        dropout_prob_t = ctx.saved_tensors

    head_dim = inputs.size(2) // heads_t[0]

    # Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)
    # Sequences and heads are combined to make the batch of the Batched GEMM
    # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
    # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
    input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads_t[0],
                                               3, head_dim)
    queries = input_lin_results[:, :, 0, :]
    keys = input_lin_results[:, :, 1, :]
    values = input_lin_results[:, :, 2, :]

    # Slice out q,k,v from one big set of gradients entering the input linear's bprop
    # (should only impact metadata, no copies!)
    # The gradients are identical in size to the Input Linear outputs.
    # The tensor is declared beforehand to properly slice out query, key, and value grads.
    input_lin_results_grads = torch.empty_like(input_lin_results)
    queries_grads = input_lin_results_grads[:, :, 0, :]
    keys_grads = input_lin_results_grads[:, :, 1, :]
    values_grads = input_lin_results_grads[:, :, 2, :]

    # Output Linear GEMM - DGRAD
    # Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
    # Input2: (weights)    [embed_dim, embed_dim]
    # Output:              [seql_q, seqs, embed_dim]
    # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
    output_lin_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)),
        output_weights)
    output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1),
                                             output_weights.size(1))

    # Output Linear GEMM - WGRAD
    # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
    # Input2: (activations) [seql_q*seqs, embed_dim]
    # Output:               [embed_dim, embed_dim]
    # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
    output_weight_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1),
                          output_grads.size(2)).transpose(0, 1),
        matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1),
                             matmul2_results.size(2)))
    output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1) * heads_t[0],
                                             head_dim).transpose(0, 1)

    if use_biases_t[0]:
        output_bias_grads = torch.sum(
            output_grads.view(output_grads.size(0) * output_grads.size(1),
                              output_grads.size(2)), 0)
    else:
        output_bias_grads = None

    # Matmul2 - DGRAD1
    # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
    # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
    # Output:               [seqs*heads, seql_q, seql_k]
    # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
    matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))

    # Matmul2 - DGRAD2
    # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
    # Input2: (data grads)  [seqs*heads, seql_q, head_dim]
    # Output:               [seql_k, seqs*heads, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
    values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads,
                             out=values_grads.transpose(0, 1))

    # Mask and Scaling for Dropout (not a publicly documented op)
    dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask,
                                        1.0 / (1.0 - dropout_prob_t[0]))

    # Softmax Grad (not a publicly documented op)
    softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1,
                                                 softmax_results)

    # Matmul1 - DGRAD1
    # Input1: (data grads)  [seqs*heads, seql_q, seql_k]
    # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
    # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
    queries_grads = torch.baddbmm(queries_grads.transpose(0, 1), softmax_grads,
                                  keys.transpose(0, 1), out=queries_grads.transpose(0, 1),
                                  beta=0.0, alpha=scale_t[0])

    # Matmul1 - DGRAD2
    # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)
    # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
    # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
    keys_grads = torch.baddbmm(keys_grads.transpose(0, 1), softmax_grads.transpose(1, 2),
                               queries.transpose(0, 1), out=keys_grads.transpose(0, 1),
                               beta=0.0, alpha=scale_t[0])

    # Input Linear GEMM - DGRAD
    # input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)]
    # input2: (weights)    [embed_dim*3 (3072), embed_dim (1024)]
    # output:              [seql_q, seqs, embed_dim]
    # GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
    input_lin_results_grads = input_lin_results_grads.view(inputs.size(0) * inputs.size(1),
                                                           heads_t[0] * 3 * head_dim)
    input_grads = torch.mm(input_lin_results_grads, input_weights)
    input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))

    # Input Linear GEMM - WGRAD
    # input1: (data grads)  [seql_q*seqs, 3*embed_dim(3072)]
    # input2: (activations) [seql_q*seqs, embed_dim(1024)]
    # output:               [3*embed_dim, embed_dim]
    # GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( 3*embed_dim x embed_dim )
    input_weight_grads = torch.mm(input_lin_results_grads.transpose(0, 1),
                                  inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)))

    if use_biases_t[0]:
        input_bias_grads = torch.sum(input_lin_results_grads, 0)
    else:
        input_bias_grads = None

    return None, None, None, None, \
        input_grads, \
        input_weight_grads, output_weight_grads, \
        input_bias_grads, output_bias_grads, \
        None, None
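
# The dropout rescale factor matters: inverted dropout multiplies surviving
# activations by 1/(1-p) in the forward pass, so the backward must apply the same
# factor (this is exactly what torch._masked_scale(grad, mask, 1/(1-p)) computes;
# passing p itself would silently shrink the gradients). A quick self-contained sketch:
def _inverted_dropout_grad_sketch():
    import torch
    p = 0.1
    x = torch.randn(32, requires_grad=True)
    mask = (torch.rand(32) > p).float()
    y = x * mask * (1.0 / (1.0 - p))  # forward: inverted dropout
    y.backward(torch.ones_like(y))
    assert torch.allclose(x.grad, mask * (1.0 / (1.0 - p)))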
def backward(ctx, *output_grads):
    incremental = ctx.incremental
    len_q = ctx.len_q
    len_key = ctx.len_k

    if ctx.return_coverage:
        output_grads, coverage_grads = output_grads
    else:
        output_grads = output_grads[0]

    heads_t, scale_t, matmul2_results, dropout_results, softmax_results, \
        input_lin_q_results, input_lin_kv_results, \
        inputs_q, inputs_kv, \
        input_weights_q, input_weights_kv, output_weights, \
        dropout_mask, dropout_prob_t \
        = ctx.saved_tensors

    head_dim = inputs_q.size(2) // heads_t[0]
    bsz = inputs_q.size(1)

    if ctx.fused_all:
        assert encdec_multihead_attn_cuda is not None and len_key <= 2048
        input_q_grads, \
            input_kv_grads, \
            input_weight_q_grads, \
            input_weight_kv_grads, \
            output_weight_grads = encdec_multihead_attn_cuda.backward(
                heads_t[0], output_grads, matmul2_results, dropout_results,
                softmax_results, input_lin_q_results, input_lin_kv_results,
                inputs_q, inputs_kv, input_weights_q, input_weights_kv,
                output_weights, dropout_mask, dropout_prob_t[0])

        return None, None, None, \
            input_q_grads, input_kv_grads, \
            input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \
            None, None, None, None, None, None

    # Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)
    # Batch sizes and heads are combined to make the batch of the Batched GEMM
    # input_lin_kv_results: [seql_k, bsz, heads(16), 2, head_dim(64)]
    # input_lin_kv_results: [seql_k, batches=bsz*heads, 2, head_dim]
    queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads_t[0], head_dim)
    input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0),
                                                     inputs_kv.size(1) * heads_t[0], 2, head_dim)
    keys = input_lin_kv_results[:, :, 0, :]
    values = input_lin_kv_results[:, :, 1, :]

    # Slice out k,v from one big set of gradients entering the input linear's bprop
    # (should only impact metadata, no copies!)
    # The gradients are identical in size to the Input Linear outputs.
    # The tensor is declared beforehand to properly slice out query, key, and value grads.
    input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)
    queries_grads = torch.empty_like(queries)
    keys_grads = input_lin_kv_results_grads[:, :, 0, :]
    values_grads = input_lin_kv_results_grads[:, :, 1, :]

    # Output Linear GEMM - DGRAD
    # Input1: (data grads) [seql_q, bsz, embed_dim=heads*head_dim]
    # Input2: (weights)    [embed_dim, embed_dim]
    # Output:              [seql_q, seqs, embed_dim]
    # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
    output_lin_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)),
        output_weights)
    output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1),
                                             output_weights.size(1))

    # Output Linear GEMM - WGRAD
    # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
    # Input2: (activations) [seql_q*seqs, embed_dim]
    # Output:               [embed_dim, embed_dim]
    # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
    output_weight_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1),
                          output_grads.size(2)).transpose(0, 1),
        matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1),
                             matmul2_results.size(2)))
    output_lin_grads = output_lin_grads.view(output_grads.size(0),
                                             output_grads.size(1) * heads_t[0],
                                             head_dim).transpose(0, 1)

    # Matmul2 - DGRAD1
    # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
    # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
    # Output:               [seqs*heads, seql_q, seql_k]
    # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
    matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))

    # Matmul2 - DGRAD2
    # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
    # Input2: (data grads)  [seqs*heads, seql_q, head_dim]
    # Output:               [seql_k, seqs*heads, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
    values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads,
                             out=values_grads.transpose(0, 1))

    if mask_softmax_dropout_cuda is not None and matmul2_dgrad1.type() == 'torch.cuda.HalfTensor' \
            and len_key <= 2048:
        softmax_grads = mask_softmax_dropout_cuda.backward_recompute(
            heads_t[0], matmul2_dgrad1, softmax_results, dropout_mask, dropout_prob_t[0])
        # a safe alternative that does not recompute:
        #   softmax_grads = mask_softmax_dropout_cuda.backward(heads_t[0], matmul2_dgrad1,
        #                                                      softmax_results, dropout_mask,
        #                                                      dropout_prob_t[0])
        # (optional) verification against the eager path, on cloned inputs:
        #   dropout_grads = torch._masked_scale(matmul2_dgrad1.clone(), dropout_mask,
        #                                       1.0 / (1.0 - dropout_prob_t[0]))
        #   softmax_grads_ref = torch._softmax_backward_data(dropout_grads,
        #                                                    softmax_results.clone(), -1,
        #                                                    softmax_results.clone())
        #   assert torch.allclose(softmax_grads_ref, softmax_grads, rtol=1e-03, atol=1e-04)
    else:
        # Mask and Scaling for Dropout (not a publicly documented op)
        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask,
                                            1.0 / (1.0 - dropout_prob_t[0]))
        # Softmax Grad (not a publicly documented op)
        softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1,
                                                     softmax_results)

    # Matmul1 - DGRAD1
    # Input1: (data grads)  [seqs*heads, seql_q, seql_k]
    # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
    # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
    queries_grads = torch.baddbmm(queries_grads.transpose(0, 1), softmax_grads,
                                  keys.transpose(0, 1), out=queries_grads.transpose(0, 1),
                                  beta=0.0, alpha=scale_t[0])

    # Matmul1 - DGRAD2
    # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)
    # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
    # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
    torch.baddbmm(keys_grads.transpose(0, 1), softmax_grads.transpose(1, 2),
                  queries.transpose(0, 1), out=keys_grads.transpose(0, 1),
                  beta=0.0, alpha=scale_t[0])

    # Input Q Linear GEMM - DGRAD
    # input1: (data grads) [seql_q, seqs, embed_dim(1024)]
    # input2: (weights)    [embed_dim (1024), embed_dim (1024)]
    # output:              [seql_q, seqs, embed_dim]
    # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
    queries_grads = queries_grads.transpose(0, 1).view(inputs_q.size(0) * inputs_q.size(1),
                                                       heads_t[0] * head_dim)
    input_q_grads = torch.mm(queries_grads, input_weights_q)
    input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))

    # Input KV Linear GEMM - DGRAD
    # input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]
    # input2: (weights)    [embed_dim*2 (2048), embed_dim (1024)]
    # output:              [seql_k, seqs, embed_dim]
    # GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = ( seql_k*seqs x embed_dim )
    # the key and value grads are already stored in the (shared) keys_grads and values_grads slices
    input_lin_kv_results_grads = input_lin_kv_results_grads.view(
        inputs_kv.size(0) * inputs_kv.size(1), heads_t[0] * 2 * head_dim)
    input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)
    input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2))

    # Input Q Linear GEMM - WGRAD
    # input1: (data grads)  [seql_q*seqs, embed_dim(1024)]
    # input2: (activations) [seql_q*seqs, embed_dim(1024)]
    # output:               [embed_dim, embed_dim]
    # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
    input_weight_q_grads = torch.mm(
        queries_grads.transpose(0, 1),
        inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)))

    # Input KV Linear GEMM - WGRAD
    # input1: (data grads)  [seql_k*seqs, 2*embed_dim(2048)]
    # input2: (activations) [seql_k*seqs, embed_dim(1024)]
    # output:               [2*embed_dim, embed_dim]
    # GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = ( 2*embed_dim x embed_dim )
    input_weight_kv_grads = torch.mm(
        input_lin_kv_results_grads.transpose(0, 1),
        inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)))

    return None, None, None, \
        input_q_grads, input_kv_grads, \
        input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \
        None, None, None, None, None, None
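
# The dispatch pattern above recurs throughout this file: use the fused CUDA kernel
# only for half-precision tensors with key length at most 2048, otherwise fall back
# to the eager ops. A sketch of the guard factored out (hypothetical helper; the
# limits mirror what the kernels here are compiled for):
def _can_use_fused_softmax_dropout(grads, len_key, double_precision=False):
    return (mask_softmax_dropout_cuda is not None
            and grads.type() == 'torch.cuda.HalfTensor'
            and len_key <= 2048
            and not double_precision)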
def backward(ctx, output_grads, softmax_grads):
    heads_t, scale_t, matmul2_results, dropout_results, softmax_results, \
        q, q_mm, q_r, kv, kv_mm, kv_r, \
        inputs_q, inputs_kv, \
        input_weights_q, input_biases_q, r_q, s_q, \
        input_weights_kv, input_biases_kv, r_kv, s_kv, \
        output_weights, output_biases, \
        dropout_mask, dropout_prob_t \
        = ctx.saved_tensors

    head_dim = inputs_q.size(2) // heads_t[0]

    # Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)
    # Batch sizes and heads are combined to make the batch of the Batched GEMM
    # kv: [seql_k, bsz, heads(16), 2, head_dim(64)]
    # kv: [seql_k, batches=bsz*heads, 2, head_dim]
    queries = q.view(inputs_q.size(0), inputs_q.size(1) * heads_t[0], head_dim)
    kv = kv.view(inputs_kv.size(0), inputs_kv.size(1) * heads_t[0], 2, head_dim)
    keys = kv[:, :, 0, :]
    values = kv[:, :, 1, :]

    # Slice out k,v from one big set of gradients entering the input linear's bprop
    # (should only impact metadata, no copies!)
    # The gradients are identical in size to the Input Linear outputs.
    # The tensor is declared beforehand to properly slice out query, key, and value grads.
    kv_grads = torch.empty_like(kv)
    queries_grads = torch.empty_like(queries)
    keys_grads = kv_grads[:, :, 0, :]
    values_grads = kv_grads[:, :, 1, :]

    # Output Linear Projection
    o_input = matmul2_results
    # output_lin_grads, output_weights_grads, output_biases_grads, r_o_grads, s_o_grads \
    #     = mm.backward(output_grads, o_input, o_r, o_mm, output_weights, r_o, s_o)
    output_lin_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)),
        output_weights)
    output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1),
                                             output_weights.size(1))
    output_weights_grads = torch.mm(
        output_grads.view(output_grads.size(0) * output_grads.size(1),
                          output_grads.size(2)).transpose(0, 1),
        matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1),
                             matmul2_results.size(2)))
    output_biases_grads = torch.sum(
        output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
    output_lin_grads = output_lin_grads.view(output_grads.size(0),
                                             output_grads.size(1) * heads_t[0],
                                             head_dim).transpose(0, 1)

    # Matmul2 - DGRAD1
    # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)
    # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
    # Output:               [seqs*heads, seql_q, seql_k]
    # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
    matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))

    # Matmul2 - DGRAD2
    # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
    # Input2: (data grads)  [seqs*heads, seql_q, head_dim]
    # Output:               [seql_k, seqs*heads, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
    torch.bmm(dropout_results.transpose(1, 2), output_lin_grads,
              out=values_grads.transpose(0, 1))

    # Mask and Scaling for Dropout (not a publicly documented op)
    dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask,
                                        1.0 / (1.0 - dropout_prob_t[0]))

    # Softmax Grad (not a publicly documented op)
    softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1,
                                                 softmax_results)

    # Matmul1 - DGRAD1
    # Input1: (data grads)  [seqs*heads, seql_q, seql_k]
    # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
    # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
    queries_grads = torch.baddbmm(queries_grads.transpose(0, 1), softmax_grads,
                                  keys.transpose(0, 1), out=queries_grads.transpose(0, 1),
                                  beta=0.0, alpha=scale_t[0])

    # Matmul1 - DGRAD2
    # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)
    # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
    # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
    torch.baddbmm(keys_grads.transpose(0, 1), softmax_grads.transpose(1, 2),
                  queries.transpose(0, 1), out=keys_grads.transpose(0, 1),
                  beta=0.0, alpha=scale_t[0])

    # Input Q Linear (factorized) - DGRAD/WGRAD through mm.backward
    # input1: (data grads) [seql_q, seqs, embed_dim(1024)]
    # input2: (weights)    [embed_dim (1024), embed_dim (1024)]
    # output:              [seql_q, seqs, embed_dim]
    queries_grads = queries_grads.transpose(0, 1).view(inputs_q.size(0), inputs_q.size(1),
                                                       heads_t[0] * head_dim)
    inputs_q_grads, input_weights_q_grads, input_biases_q_grads, r_q_grads, s_q_grads \
        = mm.backward(queries_grads, inputs_q, q_r, q_mm, input_weights_q, r_q, s_q)

    # Input KV Linear (factorized) - DGRAD/WGRAD through mm.backward
    kv_grads = kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), heads_t[0] * 2 * head_dim)
    inputs_kv_grads, input_weights_kv_grads, input_biases_kv_grads, r_kv_grads, s_kv_grads \
        = mm.backward(kv_grads, inputs_kv, kv_r, kv_mm, input_weights_kv, r_kv, s_kv)

    return None, None, None, None, \
        inputs_q_grads, inputs_kv_grads, \
        input_weights_q_grads, input_weights_kv_grads, output_weights_grads, \
        input_biases_q_grads, input_biases_kv_grads, output_biases_grads, \
        r_q_grads, s_q_grads, r_kv_grads, s_kv_grads, \
        None, None, None, None, None
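
# As with every torch.autograd.Function, the tuple returned by backward must align
# one-to-one with the arguments of the matching forward, with None in the slots of
# non-differentiable inputs (head counts, masks, flags); hence the long None-padded
# returns throughout this file. A toy illustration of the contract:
class _ScaleExample(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):            # two forward inputs...
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):            # ...so backward returns two values
        return grad_out * ctx.factor, None  # None for the non-tensor `factor`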