def forward(inputs, double_precision, dropout_prob, is_training, heads):
    """
    :param heads: number of attention heads
    :param is_training: training state, for dropout
    :param dropout_prob: attention dropout probability
    :param inputs: attention scores (softmax is taken over the last dimension)
    :param double_precision: ops at float64, only for debugging
    :return: dropout_mask, softmax_results, dropout_results
    """
    len_k = inputs.size(-1)

    if mask_softmax_dropout_cuda and len_k <= 2048 and inputs.type() == 'torch.cuda.HalfTensor':
        dropout_mask, softmax_results, dropout_results = \
            mask_softmax_dropout_cuda.forward(is_training, heads, inputs, dropout_prob)

        # the fused kernel only returns valid dropout results during training
        if not is_training:
            dropout_results = softmax_results
    else:
        dtype_ = torch.float64 if double_precision else torch.float32
        softmax_results = F.softmax(inputs, dim=-1, dtype=dtype_)

        # Dropout - is not executed for inference
        if is_training:
            dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1. - dropout_prob))
        else:
            dropout_results = softmax_results
            dropout_mask = torch.tensor([])

    return dropout_mask, softmax_results, dropout_results
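# Illustrative reference only (not used by the fast path above): a minimal plain-PyTorch sketch of the
# softmax + inverted-dropout contract that the fused mask_softmax_dropout_cuda kernel is expected to
# honour, mirroring the commented-out verification code further below. The helper name and the explicit
# Bernoulli keep-mask are assumptions for illustration, not part of this module's API.
import torch
import torch.nn.functional as F


def softmax_dropout_reference(inputs, dropout_prob, is_training):
    # softmax over the key dimension, then (in training) inverted dropout:
    # out = softmax(x) * keep_mask / (1 - p), returned together with the boolean keep-mask
    softmax_results = F.softmax(inputs, dim=-1)
    if not is_training:
        return torch.tensor([]), softmax_results, softmax_results
    dropout_mask = torch.rand_like(softmax_results) >= dropout_prob
    dropout_results = softmax_results * dropout_mask.to(softmax_results.dtype) / (1.0 - dropout_prob)
    return dropout_mask, softmax_results, dropout_results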
def forward(ctx, inputs, pos, use_time_mask, is_training, heads, input_weights, output_weights,
            pos_weights, input_biases, output_biases, pos_biases, r_w_bias, r_r_bias,
            mask, dropout_prob, incremental, incremental_cache, double_precision,
            learnable_pos, return_coverage, recompute):
    """
    :param recompute: recompute the attention activations in the backward pass to save memory
    :param return_coverage: also return the attention weights together with the outputs
    :param learnable_pos: positions are learned embeddings [len_q x len_k x head_dim], no position projection
    :param double_precision: ops at float64, only for debugging
    :param ctx: context object to stash information for backward
    :param inputs: input hidden states [len_q x batch_size x hidden]
    :param pos: [len_k x 1 x hidden]
    :param use_time_mask: bool, if we use the causal mask for decoder
    :param is_training: training state, for dropout
    :param heads: number of heads
    :param input_weights: weight matrix [hidden x 3*hidden]
    :param output_weights: output weight [hidden x hidden]
    :param input_biases: bias [3*hidden]
    :param output_biases: output bias [hidden]
    :param pos_biases: position projection bias [hidden]
    :param pos_weights: position projection weight [hidden x hidden]
    :param r_w_bias: content attention bias [heads x head_dim]
    :param r_r_bias: position attention bias [heads x head_dim]
    :param mask: None or [B x T] or [T x T]
    :param dropout_prob: attention dropout probability
    :param incremental: bool, incremental decoding
    :param incremental_cache: dict caching keys and values for incremental decoding
    :return: outputs (and the attention weights if return_coverage)
    """
    heads_t = torch.tensor([heads])
    dropout_prob_t = torch.tensor([dropout_prob])
    null_tensor = torch.tensor([]).to(inputs.device)
    head_dim = inputs.size(2) // heads
    scale_t = torch.tensor([head_dim ** -0.5])

    ctx.double_precision = double_precision
    ctx.fused_softmax_dropout = False
    ctx.learnable_pos = learnable_pos
    ctx.return_coverage = return_coverage
    ctx.fused_all = False
    ctx.recompute = recompute

    bsz, len_q = inputs.size(1), inputs.size(0)
    len_r = pos.size(0)  # r can be longer than the query, i.e. for bidirectional attention we need 2k+1 positions
    len_k = len_q  # because of self-attention

    if mask is not None:
        mask = mask.to(torch.bool)
        # Self Attention Time Mask
        if use_time_mask:
            assert (len(mask.size()) == 2), "Timing mask is not 2D!"
            # assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
            mask = mask.unsqueeze(0).unsqueeze(0)
        # Key Padding Mask
        else:
            # attn_score = attn_score.view(bsz, heads, len_q, len_k)
            mask = mask.unsqueeze(1).unsqueeze(2)

    if rel_self_attn_cuda is not None and not incremental and len_k <= 2048 and \
            inputs.type() == 'torch.cuda.HalfTensor' and learnable_pos:
        input_lin_results, rr_head_q, rw_head_q, \
            softmax_results, dropout_results, dropout_mask, \
            matmul2_results, outputs \
            = rel_self_attn_cuda.forward(is_training, heads, inputs, pos,
                                         input_weights, output_weights,
                                         input_biases, output_biases,
                                         r_w_bias, r_r_bias, mask, dropout_prob)

        pos_lin_results = None
        r_head_k = None
        nan_mask = null_tensor

        if recompute:
            ctx.save_for_backward(heads_t, scale_t, inputs, pos, r_head_k,
                                  input_weights, pos_weights, output_weights,
                                  input_biases, pos_biases, output_biases,
                                  r_w_bias, r_r_bias,
                                  dropout_mask, nan_mask, mask, dropout_prob_t)
        else:
            ctx.save_for_backward(heads_t, scale_t,
                                  matmul2_results, dropout_results, softmax_results,
                                  input_lin_results, pos_lin_results,
                                  rw_head_q, rr_head_q,
                                  inputs, pos, r_head_k,
                                  input_weights, pos_weights, output_weights,
                                  dropout_mask, nan_mask, dropout_prob_t)

        ctx.fused_all = True

        if return_coverage:
            return (outputs, dropout_results)
        else:
            return outputs

    if pos.size(1) == 1 and not learnable_pos:
        pos = pos.repeat(1, bsz, 1)  # we have to use repeat instead of expand here because mm needs contiguous

    # Input Linear GEMM
    # input1: (activations) [len_q, bsz, hidden]
    # input2: (weights) [hidden*3 (3072), hidden (1024)] (transpose [0,1])
    # output: [len_q, bsz, hidden*3]
    # GEMM: ( (len_q*bsz) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (len_q*bsz x embed_dim*3)
    input_lin_results = torch.addmm(input_biases,
                                    inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
                                    input_weights.transpose(0, 1),
                                    beta=1., alpha=1.)

    # reshape [len_q*bsz, embed_dim*3] -> [len_q x bsz x embed_dim*3]
    input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))

    # check = torch.allclose(input_lin_results, input_lin_results2)
    # print("Check linear in", check)

    if not learnable_pos:
        pos_lin_results = torch.addmm(pos_biases,
                                      pos.view(pos.size(0) * pos.size(1), pos.size(2)),
                                      pos_weights.transpose(0, 1),
                                      beta=1., alpha=1.)

        pos_lin_results = pos_lin_results.view(pos.size(0), pos.size(1), pos_weights.size(0))

        r_head_k = pos_lin_results.view(pos.size(0), bsz * heads, head_dim)  # T x BxH x D
    else:
        # pos_lin_results = pos.view(pos.size(0), bsz * heads, head_dim)  # T x BxH x D
        # r_head_k = pos_lin_results
        pos_lin_results = None
        r_head_k = None

    # Slice out q,k,v from one big Input Linear output (should only impact meta data, no copies!)
    # Sequences and heads are combined to make the batch of the Batched GEMM
    # input_lin_results: [len_q, bsz, heads(16), 3, head_dim(64)]
    # input_lin_results: [len_q, batches=bsz*heads, 3, head_dim]
    input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads, 3, head_dim)
    queries = input_lin_results[:, :, 0, :]
    keys = input_lin_results[:, :, 1, :]
    values = input_lin_results[:, :, 2, :]

    if incremental:
        # We have to merge heads x head_dim first and then concat along the T dim
        # bsz is changed during translation due to beam search
        # during translation we want to keep the actual T dim in MM as 1 constantly
        keys = keys.reshape(len_q, bsz, heads * head_dim)
        values = values.reshape(len_q, bsz, heads * head_dim)

        if 'k' in incremental_cache and 'v' in incremental_cache:
            keys = torch.cat([incremental_cache['k'], keys], dim=0)  # time first
            incremental_cache['k'] = keys
            values = torch.cat([incremental_cache['v'], values], dim=0)  # time first
            incremental_cache['v'] = values
        else:
            incremental_cache['k'] = keys
            incremental_cache['v'] = values

        keys = keys.view(-1, bsz * heads, head_dim)
        values = values.view(-1, bsz * heads, head_dim)
        # re-update len_k to be the newly updated length of the keys
        len_k = keys.size(0)

    # Relative Attention from here:
    # r_w_bias size: heads * head_dim
    rw_head_q = queries.view(len_q, bsz, heads, head_dim) + r_w_bias
    rw_head_q = rw_head_q.view(len_q, bsz * heads, head_dim)

    # matmul_ac batched GEMMs
    # queries+bias: [len_q, bsz*heads, head_dim] transpose(0, 1)
    # keys: [len_k, bsz*heads, head_dim] transpose(0, 1)
    if queries.is_cuda:
        matmul_ac = torch.empty((bsz * heads, queries.size(0), keys.size(0)),
                                dtype=queries.dtype, device=rw_head_q.device)
        matmul_ac = torch.baddbmm(matmul_ac, rw_head_q.transpose(0, 1),
                                  keys.transpose(0, 1).transpose(1, 2),
                                  out=matmul_ac, beta=0.0, alpha=scale_t[0])
    else:
        matmul_ac = torch.bmm(rw_head_q.transpose(0, 1),
                              keys.transpose(0, 1).transpose(1, 2)).mul_(scale_t[0])

    rr_head_q = queries.view(len_q, bsz, heads, head_dim) + r_r_bias
    # check = torch.allclose(rr_head_q.view(len_q, bsz, -1), rr_head_q2, rtol=1e-03, atol=1e-04)
    # print("Check rr_head_q", check)
    rr_head_q = rr_head_q.view(len_q, bsz * heads, head_dim)

    if not learnable_pos:
        if queries.is_cuda:
            # matmul_bd batched GEMMs
            # queries+bias: [len_q, bsz*heads, head_dim] transpose(0, 1)
            # rel_positions: [len_r, bsz*heads, head_dim] transpose(0, 1)
            matmul_bd = torch.empty((bsz * heads, queries.size(0), len_r),
                                    dtype=queries.dtype, device=rw_head_q.device)
            matmul_bd = torch.baddbmm(matmul_bd, rr_head_q.transpose(0, 1),
                                      r_head_k.transpose(0, 1).transpose(1, 2),
                                      out=matmul_bd, beta=0.0, alpha=scale_t[0])
        else:
            matmul_bd = torch.matmul(rr_head_q.transpose(0, 1),
                                     r_head_k.transpose(0, 1).transpose(1, 2)).mul_(scale_t[0])

        # shift so that the relative positions are aligned
        # the first element will have 0 -1 ... -n relative positions compared to other elements
        # the last element will have n-1 n-2 ... 0
        matmul_bd = RelativeShift.forward(matmul_bd, True, False)

        # if len_r is longer than len_k, then we need to take the first len_k positions only
        matmul_bd = matmul_bd[:, :, :len_k]

        attn_score = matmul_ac + matmul_bd  # both AC and BD were already scaled with scale_t in baddbmm
    else:
        # matmul_bd batched GEMMs
        # queries+bias: [len_q, bsz*heads, head_dim]
        # rel_positions: [len_q, len_k, head_dim] transpose(1, 2)
        # add directly into matmul_ac so we don't need an extra addition:
        # torch.baddbmm(matmul_ac.transpose(0, 1), rr_head_q, pos.transpose(1, 2),
        #               out=matmul_ac.transpose(0, 1), beta=1.0, alpha=scale_t[0])
        matmul_ac.transpose(0, 1).baddbmm_(rr_head_q, pos.transpose(1, 2), beta=1.0, alpha=scale_t[0])

        attn_score = matmul_ac  # no need to shift in this case

    # attn_score should have size [bsz*heads, len_q, len_k] by now
    if mask is not None:
        attn_score.view(bsz, heads, len_q, len_k).masked_fill_(mask, float('-inf'))

    if not (mask_softmax_dropout_cuda is not None and len_k <= 2048
            and attn_score.type() == 'torch.cuda.HalfTensor') or double_precision:

        dtype_ = torch.float64 if double_precision else torch.float32
        softmax_results = F.softmax(attn_score, dim=-1, dtype=dtype_).type_as(attn_score)

        # Dropout - is not executed for inference
        if is_training:
            dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1. - dropout_prob_t[0]))
        else:
            dropout_results = softmax_results
            dropout_mask = null_tensor

        ctx.fused_softmax_dropout = False
    else:
        # Fused Softmax and Dropout
        # asserted to produce the same result as F.softmax
        dropout_mask, softmax_results, dropout_results = \
            mask_softmax_dropout_cuda.forward(is_training, heads, attn_score, dropout_prob_t[0])

        if not is_training:
            dropout_results = softmax_results

        # Verification
        # softmax_results_ref = F.softmax(attn_score, dim=-1)
        # if is_training:
        #     dropout_results_ref = softmax_results_ref * dropout_mask.half() * (1 / (1 - dropout_prob_t[0]))
        # else:
        #     dropout_results_ref = softmax_results_ref
        #
        # comp = torch.allclose(softmax_results_ref, softmax_results, rtol=1e-03, atol=1e-04)
        # comp = torch.allclose(dropout_results_ref, dropout_results, rtol=1e-03, atol=1e-04)
        # if comp:
        #     print("Forward pass verification passed.")
        # else:
        #     print("ERROR: Forward pass verification failed")
        #     print(dropout_results - dropout_results_ref)
        #     print(softmax_results)
        # Done Verification

        ctx.fused_softmax_dropout = True

    nan_mask = null_tensor
    # nan_mask = torch.isnan(softmax_results)
    # if nan_mask.any():
    #     softmax_results.masked_fill_(nan_mask, 0)

    # Matmul2 Batched GEMMs
    # Input1: from_softmax [bsz*heads, len_q, seql_k]
    # Input2: (values) [seql_v, bsz*heads, head_dim] transpose(0,1)
    # Output: [bsz*heads, len_q, head_dim]
    # GEMM: Per batch: ( len_q x seql_k ) x ( seql_k x head_dim ) = (len_q x head_dim)
    matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1)).transpose(0, 1)

    # if learnable_pos:
    #     # Input1: from_softmax [bsz*heads, len_q, seql_k].transpose(0, 1)
    #     # Input2: R [len_q, len_k, head_dim]
    #     # Output: [len_q, bsz*heads, head_dim]
    #     torch.baddbmm(matmul2_results, dropout_results.transpose(0, 1), pos, beta=1.0, alpha=1.0,
    #                   out=matmul2_results)

    matmul2_results = matmul2_results.contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))
    # Output Linear GEMM
    # Input1: (activations) [len_q, bsz, embed_dim=heads*head_dim]
    # Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
    # Output: [ len_q, bsz, embed_dim ]
    # GEMM: ( len_q*bsz x embed_dim ) x ( embed_dim x embed_dim ) = ( len_q*bsz x embed_dim )
    outputs = torch.addmm(output_biases,
                          matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
                          output_weights.transpose(0, 1),
                          beta=1., alpha=1.)

    outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))

    if recompute:
        ctx.save_for_backward(heads_t, scale_t, inputs, pos, r_head_k,
                              input_weights, pos_weights, output_weights,
                              input_biases, pos_biases, output_biases,
                              r_w_bias, r_r_bias,
                              dropout_mask, nan_mask, mask, dropout_prob_t)

        # free the activations that will be recomputed in the backward pass
        del input_lin_results, queries, keys, values
        del matmul_ac, matmul2_results, attn_score, softmax_results, dropout_results
        del rr_head_q, rw_head_q

        if not learnable_pos:
            del matmul_bd

        dropout_results = null_tensor
    else:
        ctx.save_for_backward(heads_t, scale_t,
                              matmul2_results, dropout_results, softmax_results,
                              input_lin_results, pos_lin_results,
                              rw_head_q, rr_head_q,
                              inputs, pos, r_head_k,
                              input_weights, pos_weights, output_weights,
                              dropout_mask, nan_mask, dropout_prob_t)

        del attn_score

    if return_coverage:
        return (outputs, dropout_results)
    else:
        return outputs
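# Hedged reference sketch of the relative-shift trick that the BD term above depends on, in the
# standard Transformer-XL "pad, reshape, slice" form for a [bsz*heads, len_q, len_r] score tensor.
# The actual RelativeShift.forward used above may follow a different argument or ordering convention,
# so treat this purely as an illustration. For reference, the unfused path computes (per bsz*heads batch):
#     attn_score = scale * ((q + r_w_bias) @ k^T + rel_shift((q + r_r_bias) @ r^T))
import torch


def relative_shift_reference(x):
    # x: [bsz * heads, len_q, len_r] raw (q + r_r_bias) @ r^T scores
    b, len_q, len_r = x.size()
    zero_pad = x.new_zeros(b, len_q, 1)
    x_padded = torch.cat([zero_pad, x], dim=2)       # [b, len_q, len_r + 1]
    x_padded = x_padded.view(b, len_r + 1, len_q)    # swap the roles of the two trailing dims
    # drop the first "row" of the reshaped view and restore the shape; the net effect is that
    # query row i is shifted left by (len_q - 1 - i) positions with zero padding on the right
    return x_padded[:, 1:].reshape(b, len_q, len_r)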
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv,
            input_weights_q, input_weights_kv, output_weights,
            mask, dropout_prob,
            incremental, incremental_cache,
            double_precision, return_coverage):
    heads_t = torch.tensor([heads])
    dropout_prob_t = torch.tensor([dropout_prob])
    null_tensor = torch.tensor([])
    head_dim = inputs_q.size(2) // heads
    scale_t = torch.tensor([head_dim ** -0.5])
    use_mask = (mask is not None)

    bsz, len_q, len_k = inputs_q.size(1), inputs_q.size(0), inputs_kv.size(0)

    ctx.incremental = incremental
    ctx.fused_softmax_dropout = False
    ctx.fused_all = False
    ctx.len_q = len_q
    ctx.len_k = len_k
    ctx.double_precision = double_precision
    ctx.return_coverage = return_coverage

    if mask is not None:
        # Self Attention Pad Mask
        mask = mask.to(torch.bool)

        if len(mask.shape) == 3:
            mask = mask.unsqueeze(1)  # for the head dimension
        else:
            mask = mask.unsqueeze(1).unsqueeze(2)  # for the head and query dimension

    if encdec_multihead_attn_cuda is not None and not incremental and len_k <= 2048 \
            and inputs_q.type() == 'torch.cuda.HalfTensor':
        input_lin_q_results, input_lin_kv_results, \
            softmax_results, dropout_results, dropout_mask, \
            matmul2_results, outputs \
            = encdec_multihead_attn_cuda.forward(is_training, heads, inputs_q, inputs_kv,
                                                 input_weights_q, input_weights_kv,
                                                 output_weights, mask, dropout_prob)

        ctx.save_for_backward(heads_t, scale_t,
                              matmul2_results, dropout_results, softmax_results,
                              input_lin_q_results, input_lin_kv_results,
                              inputs_q, inputs_kv,
                              input_weights_q, input_weights_kv, output_weights,
                              dropout_mask, dropout_prob_t)

        ctx.fused_all = True

        if return_coverage:
            return outputs, softmax_results
        else:
            return (outputs,)

    # Input Linear GEMM Q
    # input1: (activations) [seql_q, bsz, embed_dim] -> [len_q * bsz, embed_dim]
    # input2: (weights) [embed_dim, embed_dim]. transpose(0, 1)
    # output: [len_q * bsz, embed_dim] -> [seql_q, bsz, embed_dim]
    # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
    input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
                                   input_weights_q.transpose(0, 1))
    input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))

    queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads, head_dim)

    # Input Linear GEMM KV
    # input1: (activations) [seql_k, bsz, embed_dim(1024)]
    # input2: (weights) [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])
    # output: [seql_k, bsz, embed_dim*2]
    # GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)

    # Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)
    # Sequences and heads are combined to make the batch of the Batched GEMM
    if incremental and ('c_k' in incremental_cache and 'c_v' in incremental_cache):
        keys = incremental_cache['c_k']
        values = incremental_cache['c_v']
        keys = keys.view(len_k, bsz * heads, head_dim)
        values = values.view(len_k, bsz * heads, head_dim)
        input_lin_kv_results = torch.stack([keys, values], dim=-2)
    else:
        input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
                                        input_weights_kv.transpose(0, 1))
        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1),
                                                         input_weights_kv.size(0))

        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1) * heads, 2, head_dim)
        keys = input_lin_kv_results[:, :, 0, :]
        values = input_lin_kv_results[:, :, 1, :]

        if incremental:
            keys = keys.contiguous().view(len_k, bsz, heads * head_dim)
            values = values.contiguous().view(len_k, bsz, heads * head_dim)

            incremental_cache['c_k'] = keys
            incremental_cache['c_v'] = values

            keys = keys.view(len_k, bsz * heads, head_dim)
            values = values.view(len_k, bsz * heads, head_dim)

    # Matmul1 Batched GEMMs
    # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
    # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
    # a separate elementwise operation.
    # Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)
    # Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
    # output: [seqs*heads, seql_q, seql_k]
    # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
    if queries.is_cuda:
        matmul1_results = torch.empty((queries.size(1), queries.size(0), keys.size(0)),
                                      dtype=queries.dtype, device=queries.device)
        matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0, 1),
                                        keys.transpose(0, 1).transpose(1, 2),
                                        out=matmul1_results, beta=0.0, alpha=scale_t[0])
    else:
        matmul1_results = torch.matmul(queries.transpose(0, 1), keys.transpose(0, 1).transpose(1, 2))
        matmul1_results.mul_(scale_t[0])

    if mask is not None:
        batches, seql_q, seql_k = matmul1_results.size()
        bsz = int(batches / heads)
        matmul1_results = matmul1_results.view(bsz, heads, seql_q, seql_k)
        # after unsqueezing the mask should have size [bsz x 1 x 1 x seql_k]
        matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
        matmul1_results = matmul1_results.view(bsz * heads, seql_q, seql_k)

    if mask_softmax_dropout_cuda and len_k <= 2048 \
            and matmul1_results.type() == 'torch.cuda.HalfTensor' and not double_precision:
        # dropout_results_ref = F.softmax(matmul1_results, dim=-1)

        dropout_mask, softmax_results, dropout_results = \
            mask_softmax_dropout_cuda.forward(is_training, heads, matmul1_results, dropout_prob_t[0])

        if not is_training:
            dropout_results = softmax_results  # because the cuda kernel returns an empty dropout tensor at inference

        # Verification code
        # softmax_results_ref = F.softmax(matmul1_results, dim=-1)
        #
        # if is_training:
        #     # print(dropout_mask.float().sum(), dropout_mask.numel())
        #     dropout_results_ref = softmax_results_ref * dropout_mask.half() * (1 / (1 - dropout_prob_t[0]))
        # else:
        #     dropout_results_ref = softmax_results_ref
        #
        # comp = torch.allclose(softmax_results_ref, softmax_results, rtol=1e-03, atol=1e-04)
        # comp = torch.allclose(dropout_results_ref, dropout_results, rtol=1e-03, atol=1e-04)
        # if comp:
        #     print("Forward pass verification passed.")
        # else:
        #     print("ERROR: Forward pass verification failed")
        # Verification done

        ctx.fused_softmax_dropout = True
    else:
        # dtype_ = torch.float64 if double_precision else torch.float32
        # softmax_results = F.softmax(matmul1_results, dim=-1, dtype=dtype_).type_as(matmul1_results)
        if matmul1_results.type() == 'torch.cuda.HalfTensor':
            softmax_results = F.softmax(matmul1_results, dim=-1, dtype=torch.float32).type_as(matmul1_results)
        else:
            softmax_results = F.softmax(matmul1_results, dim=-1)

        # Dropout - is not executed for inference
        if is_training:
            dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1. - dropout_prob_t[0]))
        else:
            dropout_results = softmax_results
            dropout_mask = null_tensor

    # Matmul2 Batched GEMMs
    # The output tensor specification is needed here to specify the non-standard output.
    # Given that pytorch cannot currently perform autograd with an output tensor specified,
    # this requires a backward pass specified.
    # Input1: from_softmax [seqs*heads, seql_q, seql_k]
    # Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1)
    # Output: [seql_q, seqs*heads, head_dim] transpose(0,1)
    # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)
    if queries.is_cuda:
        matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)),
                                      dtype=dropout_results.dtype, device=dropout_results.device)
        torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results.transpose(1, 0))
    else:
        matmul2_results = torch.matmul(dropout_results, values.transpose(0, 1)).transpose(0, 1)

    # view from [len_q, bsz*heads, head_dim] to [len_q, bsz, embed]
    matmul2_results = matmul2_results.contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))

    # Output Linear GEMM
    # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
    # Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
    # Output: [ seql_q, seqs, embed_dim ]
    # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
    outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
                       output_weights.transpose(0, 1))
    outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))

    ctx.save_for_backward(heads_t, scale_t,
                          matmul2_results, dropout_results, softmax_results,
                          input_lin_q_results, input_lin_kv_results,
                          inputs_q, inputs_kv,
                          input_weights_q, input_weights_kv, output_weights,
                          dropout_mask, dropout_prob_t)

    if return_coverage:
        return (outputs, dropout_results)
    else:
        return (outputs,)
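# Compact reference sketch (illustration only) of what the unfused path of this encoder-decoder
# attention computes, with dropout omitted for brevity. Shapes follow the GEMM comments above;
# the function name and argument names are assumptions, not part of the module's API.
import torch
import torch.nn.functional as F


def encdec_attention_reference(inputs_q, inputs_kv, input_weights_q, input_weights_kv,
                               output_weights, heads, mask=None):
    len_q, bsz, embed_dim = inputs_q.size()
    len_k = inputs_kv.size(0)
    head_dim = embed_dim // heads
    scale = head_dim ** -0.5

    # Q projection and joint K/V projection (no biases, as in the function above)
    q = inputs_q.view(-1, embed_dim).mm(input_weights_q.t()).view(len_q, bsz * heads, head_dim)
    kv = inputs_kv.view(-1, inputs_kv.size(2)).mm(input_weights_kv.t()).view(len_k, bsz * heads, 2, head_dim)
    k, v = kv[:, :, 0, :], kv[:, :, 1, :]

    # scaled dot-product scores: [bsz*heads, len_q, len_k]
    scores = torch.bmm(q.transpose(0, 1), k.transpose(0, 1).transpose(1, 2)) * scale
    if mask is not None:
        # mask broadcasts over [bsz, heads, len_q, len_k], e.g. a [bsz, 1, 1, len_k] padding mask
        scores = scores.view(bsz, heads, len_q, len_k).masked_fill(mask, float('-inf'))
        scores = scores.view(bsz * heads, len_q, len_k)
    attn = F.softmax(scores, dim=-1)

    # weighted sum of values, then the output projection
    context = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1).contiguous().view(len_q, bsz, embed_dim)
    return context.view(-1, embed_dim).mm(output_weights.t()).view(len_q, bsz, embed_dim)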