def forward(ctx, vocab_parallel_logits, target):
    """Cross-entropy loss for logits whose last (vocab) dimension is split across ranks.

    Each rank of the model parallel group holds one slice of the vocabulary
    logits. The loss ``log(sum(exp(logits))) - logits[target]`` is computed
    without gathering the full logits. Only per-token reductions (max, sum of
    exponentials, and the target logit) are all-reduced across the group.

    Args:
        ctx: autograd context. The softmax of this rank's slice, the
            out-of-range target mask and the flattened partition-local
            targets are stored via ``ctx.save_for_backward``.
        vocab_parallel_logits: this rank's slice of the logits. The slice
            width is read from the last dimension. The tensor is not modified
            (it is cloned first).
        target: global vocabulary ids. Ids outside this rank's
            ``[vocab_start_index, vocab_end_index)`` range are masked locally.
            NOTE(review): the ``view(-1, partition_vocab_size)`` /
            ``view_as(target)`` pair implies target has as many elements as
            the leading dims of the logits, i.e. presumably
            ``target.shape == vocab_parallel_logits.shape[:-1]`` — confirm.

    Returns:
        The per-token loss ``log(sum_exp_logits) - predicted_logits``.

    NOTE: the three ``all_reduce`` calls below are collectives over the model
    parallel group. Every rank in the group must execute them in this order.
    """
    # Copy so the input remains unchanged.
    logits = vocab_parallel_logits.clone()
    # Maximum value along vocab dimension across all GPUs.
    logits_max = torch.max(logits, dim=-1)[0]
    all_reduce(logits_max, op='max', group=get_model_parallel_group())
    # Subtract the maximum value (in place) so exp() below cannot overflow.
    # The max cancels out of the final loss, because the predicted logit
    # below is read from these shifted logits too.
    logits.sub_(logits_max.unsqueeze(dim=-1))
    # Sum of exponential of logits along vocab dimension across all GPUs.
    exp_logits = logits.exp()
    sum_exp_logits = exp_logits.sum(dim=-1)
    all_reduce(sum_exp_logits, op='sum', group=get_model_parallel_group())

    # Get the partition's vocab indices.
    get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
    partition_vocab_size = vocab_parallel_logits.size()[-1]
    rank = get_model_parallel_rank()
    world_size = get_model_parallel_world_size()
    vocab_start_index, vocab_end_index = get_vocab_range(
        partition_vocab_size, rank, world_size)

    # Create a mask of valid vocab ids (1 means it needs to be masked).
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    # Shift global ids to partition-local ids. Masked positions are pointed
    # at index 0 only so the gather below stays in bounds; their values are
    # zeroed afterwards.
    masked_target = target.clone() - vocab_start_index
    masked_target[target_mask] = 0

    # Get predicted-logits = logits[target].
    # For Simplicity, we convert logits to a 2-D tensor with size
    # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
    logits_2d = logits.view(-1, partition_vocab_size)
    masked_target_1d = masked_target.view(-1)
    arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
                             device=logits_2d.device)
    # Advanced indexing returns a copy, so zeroing predicted_logits below
    # does not touch `logits`.
    predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
    predicted_logits = predicted_logits_1d.view_as(target)
    predicted_logits[target_mask] = 0.0
    # All reduce is needed to get the chunks from other GPUs. Ranks that do
    # not own a token's target contribute 0, so (assuming the per-rank vocab
    # ranges are disjoint and cover the vocabulary) the sum is the target
    # logit from the owning rank.
    all_reduce(predicted_logits, op='sum', group=get_model_parallel_group())

    # Loss = log(sum(exp(logits))) - predicted-logit.
    loss = torch.log(sum_exp_logits) - predicted_logits

    # Store softmax, target-mask and masked-target for backward pass.
    # div_ normalizes exp_logits in place into this slice's softmax
    # probabilities. It runs after the loss is formed and does not
    # change sum_exp_logits.
    exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
    ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
    return loss
def _reduce(input_):
    """All-reduce the input tensor across the model parallel group.

    ``all_reduce`` is called with its default op and its return value is
    not used. The tensor passed in is what gets returned, so the reduction
    is expected to happen in place. With a single-member group the
    collective is skipped and ``input_`` is returned untouched.
    """
    group = get_model_parallel_group()
    # Only issue the collective when more than one rank participates.
    if get_world_size(group=group) != 1:
        all_reduce(input_, group=group)
    return input_
def _gather(input_):
    """Gather tensors from all model parallel ranks and concatenate them along the last dimension.

    Args:
        input_: this rank's tensor. A contiguous version of it is handed to
            the collective, because torch.distributed backends reject
            non-contiguous tensors.

    Returns:
        The tensors returned by ``all_gather`` (presumably in group-rank
        order), concatenated along the last dimension as a contiguous tensor.
        With a single-member group, ``input_`` is returned unchanged.
    """
    group = get_model_parallel_group()
    # Bypass the function if we are using only 1 GPU.
    if get_world_size(group=group) == 1:
        return input_
    # Size and dimension.
    last_dim = input_.dim() - 1
    # Collectives require contiguous buffers. contiguous() returns input_
    # itself when it is already contiguous, so this costs nothing then.
    tensor_list = all_gather(tensor=input_.contiguous(), group=group)
    # torch.cat normally yields a contiguous result already. The trailing
    # contiguous() is a cheap guarantee of the default layout (cat can
    # propagate other memory formats from its inputs).
    output = torch.cat(tensor_list, dim=last_dim).contiguous()
    return output
def _split(input_):
    """Return this rank's slice of ``input_`` along its last dimension.

    The tensor is partitioned with
    ``split_tensor_along_last_dim(input_, world_size)``. The piece at this
    process's rank within the model parallel group is returned as a
    contiguous tensor. With a single-member group the input is returned as-is.
    """
    group = get_model_parallel_group()
    world_size = get_world_size(group=group)
    # A lone rank keeps the whole tensor.
    if world_size == 1:
        return input_
    pieces = split_tensor_along_last_dim(input_, world_size)
    # Split pieces are views that are generally not contiguous, so
    # materialize a contiguous copy of the local one.
    return pieces[get_rank(group=group)].contiguous()
def _aggregate_model_parallel_grad_norm(total_norm):
    """Combine per-rank gradient norms into one norm over the model parallel group.

    Each rank's norm is squared. The squares are all-reduced with
    ``distributed_utils.all_reduce``'s default op (presumably a sum — confirm)
    and the square root of the result is returned. Squaring produces a new
    tensor, so the caller's ``total_norm`` is not reduced in place (assumes
    ``total_norm`` is a torch tensor).
    """
    squared_norm = total_norm ** 2
    mp_group = distributed_utils.get_model_parallel_group()
    distributed_utils.all_reduce(squared_norm, group=mp_group)
    return squared_norm ** 0.5