def initialize_word_embeddings(self, init_method_normal):
    args = get_args()
    if not self.share_word_embeddings:
        raise Exception('initialize_word_embeddings() was called but '
                        'share_word_embeddings is false')
    # Parameters are shared between the word embeddings layer, and the
    # heads at the end of the model. In a pipelined setup with more than
    # one stage, the initial embedding layer and the head are on different
    # workers, so we do the following:
    # 1. Create a second copy of word_embeddings on the last stage, with
    #    initial parameters of 0.0.
    # 2. Do an all-reduce between the first and last stage to ensure that
    #    the two copies of word_embeddings start off with the same
    #    parameter values.
    # 3. In the training loop, perform an all-reduce between the grads of
    #    the two word_embeddings layers to ensure that every applied weight
    #    update is the same on both stages.
    if mpu.is_pipeline_last_stage():
        if not mpu.is_pipeline_first_stage():
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # If first and last stages are different, set word_embeddings
            # weights to 0 here, then copy the first stage's weights using
            # the all_reduce below.
            self.word_embeddings = mpu.VocabParallelEmbedding(
                args.padded_vocab_size, args.hidden_size,
                init_method=init_method_normal(args.init_method_std))
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True

    # Ensure that first and last stages have the same initial parameter
    # values.
    if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
        torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                     group=mpu.get_embedding_group())
def forward_step(batch, model, eval_metric):
    """Forward step."""

    # Get the batch.
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
        batch)

    # Tell the model what our actual batch size will be.
    args = get_args()
    args.micro_batch_size = len(labels)

    # Receive activations from the previous stage, if any.
    if not mpu.is_pipeline_first_stage():
        input_tensor, _ = communicate(tensor_send_next=None,
                                      tensor_send_prev=None,
                                      recv_forward=True,
                                      recv_backward=False)
    else:
        input_tensor = None

    # Forward pass through the model. The first stage consumes the batch
    # directly; later stages consume the received activations.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        output = model(tokens, position_ids, attention_mask)
    else:
        assert input_tensor is not None
        output = model(input_tensor, attention_mask)

    # Intermediate stages just forward their activations to the next stage.
    if not mpu.is_pipeline_last_stage():
        communicate(tensor_send_next=output,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
        return None

    # For loss, return the unreduced loss.
    if eval_metric == 'loss':
        losses = mpu.vocab_parallel_cross_entropy(
            output.contiguous().float(), labels.contiguous())
        loss = torch.sum(
            losses.view(-1) * loss_mask.contiguous().view(-1).float())
        return loss

    # For accuracy, return the number of correctly predicted samples.
    if eval_metric == 'accuracy':
        outputs = torch.argmax(output, -1)
        correct = (outputs == labels).float()
        correct[(1 - loss_mask).bool()] = 1
        correct = correct.prod(-1)
        return correct.sum()

    raise NotImplementedError('forward method for evaluation metric {} '
                              'is not implemented.'.format(eval_metric))
def initialize_word_embeddings(self, init_method_normal):
    args = get_args()
    if not self.share_word_embeddings:
        raise Exception('initialize_word_embeddings() was called but '
                        'share_word_embeddings is false')

    # This function just initializes the word embeddings in the final stage
    # when we are using pipeline parallelism. If we aren't using pipeline
    # parallelism there is nothing to do.
    if args.pipeline_model_parallel_size == 1:
        return

    # Parameters are shared between the word embeddings layer, and the
    # heads at the end of the model. In a pipelined setup with more than
    # one stage, the initial embedding layer and the head are on different
    # workers, so we do the following:
    # 1. Create a second copy of word_embeddings on the last stage, with
    #    initial parameters of 0.0.
    # 2. Do an all-reduce between the first and last stage to ensure that
    #    the two copies of word_embeddings start off with the same
    #    parameter values.
    # 3. In the training loop, perform an all-reduce between the grads of
    #    the two word_embeddings layers to ensure that every applied weight
    #    update is the same on both stages.
    if mpu.is_pipeline_last_stage():
        assert not mpu.is_pipeline_first_stage()
        self._word_embeddings_for_head_key = 'word_embeddings_for_head'
        # Set word_embeddings weights to 0 here, then copy the first
        # stage's weights using the all_reduce below.
        self.word_embeddings = mpu.VocabParallelEmbedding(
            args.padded_vocab_size, args.hidden_size,
            init_method=init_method_normal(args.init_method_std))
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True

    # Ensure that first and last stages have the same initial parameter
    # values.
    if torch.distributed.is_initialized():
        if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                         group=mpu.get_embedding_group())
    else:
        print("WARNING! Distributed processes aren't initialized, so "
              "word embeddings in the last layer are not initialized. "
              "If you are just manipulating a model this is fine, but "
              "this needs to be handled manually. If you are training "
              "something is definitely wrong.")
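# Hedged sketch of step 3 in the comment above; the training-loop side is not
# shown in this file, so the function name and the unwrapped `model` argument
# are assumptions, not the source's actual code. Before the optimizer step,
# the grads of the two word_embeddings copies are summed over the embedding
# group so both stages apply the identical update.
def allreduce_word_embedding_grads(model):
    if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
        if model.share_word_embeddings:
            weight = model.word_embeddings_weight()
            # With identical starting values (ensured above) and identical
            # grads (ensured here), the two copies stay tied.
            torch.distributed.all_reduce(weight.grad.data,
                                         group=mpu.get_embedding_group())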
def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    args = get_args()

    if not mpu.is_pipeline_first_stage():
        timers('forward-recv').start()
        input_tensor, _ = communicate(tensor_send_next=None,
                                      tensor_send_prev=None,
                                      recv_forward=True,
                                      recv_backward=False)
        timers('forward-recv').stop()
    else:
        input_tensor = None

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send').start()
        communicate(tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
        timers('forward-send').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)
def backward_step_with_communication(optimizer, model, input_tensors,
                                     output_tensors, timers):
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        timers('backward-recv').start()
        _, output_tensor_grad = communicate(tensor_send_next=None,
                                            tensor_send_prev=None,
                                            recv_forward=False,
                                            recv_backward=True)
        timers('backward-recv').stop()

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor,
                      output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send').start()
        communicate(tensor_send_next=None,
                    tensor_send_prev=input_grad_tensor,
                    recv_forward=False,
                    recv_backward=False)
        timers('backward-send').stop()
def model_provider():
    """Build the model."""

    if eval_metric == 'loss':
        parallel_output = True
    elif eval_metric == 'accuracy':
        parallel_output = False
    else:
        raise NotImplementedError('output type for {} evaluation metric '
                                  'is not supported.'.format(eval_metric))

    print_rank_0('building GPT2 model ...')
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = GPT2ModelFirstStage(num_tokentypes=0)
        elif mpu.is_pipeline_last_stage():
            model = GPT2ModelLastStage(parallel_output=parallel_output,
                                       num_tokentypes=0)
        else:
            model = GPT2ModelIntermediateStage(num_tokentypes=0)
    else:
        model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)

    return model
def forward_step_helper(microbatch_id):
    """Helper method to run forward step with model split into chunks (run
    set_virtual_pipeline_model_parallel_rank() before calling
    forward_step())."""
    model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
    mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

    # Forward step.
    if mpu.is_pipeline_first_stage():
        if len(input_tensors[model_chunk_id]) == \
                len(output_tensors[model_chunk_id]):
            input_tensors[model_chunk_id].append(None)
    input_tensor = input_tensors[model_chunk_id][-1]
    output_tensor = forward_step(forward_step_func,
                                 data_iterator[model_chunk_id],
                                 model[model_chunk_id],
                                 input_tensor, losses_reduced)
    output_tensors[model_chunk_id].append(output_tensor)

    # If forward-only, no need to save tensors for a backward pass.
    if forward_only:
        input_tensors[model_chunk_id].pop()
        output_tensors[model_chunk_id].pop()

    return output_tensor
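# Hedged sketch of the microbatch-to-chunk mapping assumed by
# forward_step_helper() above; get_model_chunk_id() is not shown in this
# file, so this body (and the free variable `num_model_chunks`, taken from
# the surrounding scheduler scope like the other closures here) is an
# assumption about one plausible interleaved schedule, not the source's code.
def get_model_chunk_id(microbatch_id, forward):
    """Map a microbatch id to a virtual-pipeline model chunk id."""
    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    # Microbatches cycle through the chunks in groups of
    # pipeline_parallel_size consecutive microbatches per chunk.
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size *
                                              num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        # Backward passes visit the chunks in reverse order.
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id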
def forward(self, *inputs, **kwargs):
    if mpu.is_pipeline_first_stage():
        inputs = fp32_to_fp16(inputs)
    outputs = self.module(*inputs, **kwargs)
    if mpu.is_pipeline_last_stage():
        outputs = fp16_to_fp32(outputs)
    return outputs
def forward(self, model_input, attention_mask, tokentype_ids=None):

    extended_attention_mask = bert_extended_attention_mask(attention_mask)

    kwargs = {}
    if mpu.is_pipeline_first_stage():
        input_ids = model_input
        position_ids = bert_position_ids(input_ids)
        args = [input_ids, position_ids, extended_attention_mask]
        kwargs['tokentype_ids'] = tokentype_ids
    else:
        args = [model_input, extended_attention_mask]
    lm_output = self.language_model(*args, **kwargs)

    if mpu.is_pipeline_last_stage():
        _, pooled_output = lm_output
        classification_output = self.classification_dropout(pooled_output)
        classification_logits = self.classification_head(
            classification_output)

        # Reshape back to separate choices.
        classification_logits = classification_logits.view(
            -1, self.num_classes)

        return classification_logits
    return lm_output
def __init__(self, attention_mask_func, init_method,
             output_layer_init_method, num_tokentypes=0, add_pooler=False):
    super(TransformerLanguageModelBase, self).__init__()
    args = get_args()

    self.hidden_size = args.hidden_size
    self.num_tokentypes = num_tokentypes
    self.init_method = init_method
    self.add_pooler = add_pooler

    # Embeddings.
    if mpu.is_pipeline_first_stage():
        self.embedding = Embedding(self.hidden_size,
                                   args.padded_vocab_size,
                                   args.max_position_embeddings,
                                   args.hidden_dropout,
                                   self.init_method,
                                   self.num_tokentypes)
        self._embedding_key = 'embedding'

    # Transformer.
    self.transformer = ParallelTransformer(attention_mask_func,
                                           self.init_method,
                                           output_layer_init_method)
    self._transformer_key = 'transformer'

    # Pooler.
    if mpu.is_pipeline_last_stage() and self.add_pooler:
        self.pooler = Pooler(self.hidden_size, self.init_method)
        self._pooler_key = 'pooler'
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from the last stage into the first stage."""
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first and last stages are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return tensor
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            _is_cuda_contiguous(tensor)
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # Broadcast from the last stage into the first stage.
        torch.distributed.broadcast(tensor, src, group)
    else:
        tensor = None
    return tensor
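# Hypothetical usage sketch (the variable names and sizes are illustrative
# assumptions, not from the source): during generation, the last stage holds
# the newly sampled tokens and the first stage needs them to build the next
# input. Non-boundary stages get None back.
new_tokens = broadcast_from_last_to_first_pipeline_stage(
    size=(batch_size, 1), dtype=torch.int64,
    tensor=new_tokens if mpu.is_pipeline_last_stage() else None)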
def forward(self, language_model_input, attention_mask,
            tokentype_ids=None, layer_past=None, get_key_value=False,
            pooling_sequence_index=0):

    # Embeddings.
    if mpu.is_pipeline_first_stage():
        (input_ids, position_ids) = language_model_input
        embedding_output = self.embedding(input_ids, position_ids,
                                          tokentype_ids=tokentype_ids)
        transformer_input = embedding_output
    else:
        transformer_input = language_model_input

    # Transformer.
    transformer_output = self.transformer(transformer_input,
                                          attention_mask,
                                          layer_past=layer_past,
                                          get_key_value=get_key_value)

    if mpu.is_pipeline_last_stage() and self.add_pooler:
        pooled_output = self.pooler(transformer_output,
                                    pooling_sequence_index)
        return transformer_output, pooled_output

    return transformer_output
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from the last stage into the first stage.
    Note that the input tensor is updated in place."""
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first and last stages are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())
        # Broadcast from the last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # Update the first stage tensor.
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_
def load_state_dict(self, state_dict, strict=True):
    """Customized load."""

    # Embedding.
    if mpu.is_pipeline_first_stage():
        if self._embedding_key in state_dict:
            state_dict_ = state_dict[self._embedding_key]
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if '_embeddings' in key:
                    state_dict_[key] = state_dict[key]
        self.embedding.load_state_dict(state_dict_, strict=strict)

    # Transformer.
    if self._transformer_key in state_dict:
        state_dict_ = state_dict[self._transformer_key]
    else:
        # For backward compatibility.
        state_dict_ = {}
        for key in state_dict.keys():
            if 'transformer.' in key:
                state_dict_[key.split('transformer.')[1]] = state_dict[key]
    self.transformer.load_state_dict(state_dict_, strict=strict)

    # Pooler.
    if mpu.is_pipeline_last_stage() and self.add_pooler:
        assert 'pooler' in state_dict, \
            'could not find data for pooler in the checkpoint'
        self.pooler.load_state_dict(state_dict[self._pooler_key],
                                    strict=strict)
def forward(self, bert_model_input, attention_mask, tokentype_ids=None,
            lm_labels=None, position_ids=None):

    # Build the extended mask only if a plain [b, s] padding mask was given;
    # otherwise assume the caller already extended it.
    if attention_mask.dim() == 2:
        extended_attention_mask = bert_extended_attention_mask(attention_mask)
    else:
        extended_attention_mask = attention_mask

    kwargs = {}
    if mpu.is_pipeline_first_stage():
        input_ids = bert_model_input
        if position_ids is None:
            position_ids = bert_position_ids(input_ids)
        args = [input_ids, position_ids, extended_attention_mask]
        kwargs['tokentype_ids'] = tokentype_ids
    else:
        args = [bert_model_input, extended_attention_mask]
    lm_output = self.language_model(*args, **kwargs)

    if mpu.is_pipeline_last_stage() and self.add_binary_head:
        lm_output, pooled_output = lm_output
    else:
        pooled_output = None

    if mpu.is_pipeline_last_stage():
        return post_language_model_processing(lm_output, pooled_output,
                                              self.lm_head, self.binary_head,
                                              lm_labels,
                                              self.word_embeddings_weight(),
                                              self.fp16_lm_cross_entropy)
    return lm_output
def _cross_entropy_forward_step(batch, model, input_tensor):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    try:
        batch_ = next(batch)
    except BaseException:
        # `batch` may be a batch itself rather than an iterator.
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask)

    if mpu.is_pipeline_last_stage():
        logits = output_tensor

        # Cross-entropy loss.
        loss_func = torch.nn.CrossEntropyLoss()
        loss = loss_func(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {'lm loss': averaged_loss[0]}
    return output_tensor
def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())
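# Hypothetical usage sketch (the numbers are illustrative assumptions): the
# buffer follows the [sequence, batch, hidden] layout used between
# transformer layers, so with hidden_size=2048 this allocates a
# (1024, 4, 2048) tensor on every stage except the first, which returns None
# because it reads its input from the data loader rather than the network.
recv_buffer = _allocate_recv_buffer(batch_size=4, sequence_length=1024)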
def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers):
    """Run 1F1B schedule, with communication and warmup + cooldown
    microbatches as needed."""
    args = get_args()

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches,
                                  num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors, losses_reduced, timers)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            timers('forward-recv').start()
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)
            timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        input_tensor = \
            forward_and_backward_steps_with_communication(forward_step_func,
                                                          data_iterator,
                                                          model,
                                                          optimizer,
                                                          input_tensor,
                                                          last_iteration,
                                                          input_tensors,
                                                          output_tensors,
                                                          losses_reduced,
                                                          timers)

    # Run cooldown backward passes.
    for i in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced
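# Worked example with illustrative numbers (an assumption, not from the
# source): for a 4-stage pipeline and 8 microbatches, the warmup count on
# stage rank r is min(world_size - r - 1, num_microbatches). Ranks 0..3 warm
# up with 3, 2, 1, 0 forward passes, run 5, 6, 7, 8 steady-state 1F1B pairs,
# and cool down with the same number of backward passes as warmups.
for rank in range(4):
    num_warmup = min(4 - rank - 1, 8)
    num_remaining = 8 - num_warmup
    print(rank, num_warmup, num_remaining)  # (0, 3, 5) ... (3, 0, 8)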
def word_embeddings_weight(self):
    if mpu.is_pipeline_first_stage():
        return self.language_model.embedding.word_embeddings.weight
    if mpu.is_pipeline_last_stage():
        if not self.share_word_embeddings:
            raise Exception('word_embeddings_weight() called for last '
                            'stage, but share_word_embeddings is false')
        return self.word_embeddings.weight
    raise Exception('word_embeddings_weight() should be '
                    'called for first and last stage only')
def load_state_dict(self, state_dict, strict=True):
    """Customized load."""

    # Load word_embeddings.
    if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
        self.word_embeddings.load_state_dict(
            state_dict[self._word_embeddings_for_head_key], strict=strict)

    if self._language_model_key in state_dict:
        state_dict = state_dict[self._language_model_key]
    self.language_model.load_state_dict(state_dict, strict=strict)
def send_backward(input_tensor_grad, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
        _communicate(tensor_send_next=None,
                     tensor_send_prev=input_tensor_grad,
                     recv_prev=False,
                     recv_next=False)
        if timers is not None:
            timers('backward-send').stop()
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, loss_mask, lm_labels, padding_mask, attention_mask, position_ids \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    extended_attention_mask = bert_extended_attention_mask(
        padding_mask) + attention_mask

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens, extended_attention_mask,
                                  tokentype_ids=None,
                                  lm_labels=lm_labels,
                                  position_ids=position_ids)
        else:
            output_tensor = model(tokens, extended_attention_mask,
                                  tokentype_ids=None)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, extended_attention_mask,
                              lm_labels=lm_labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, extended_attention_mask,
                              position_ids=position_ids)

    if mpu.is_pipeline_last_stage():
        lm_loss_, _ = output_tensor

        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss

        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss])

        return loss, {'lm loss': averaged_losses[0]}
    return output_tensor
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                   keep_vars=False):

    state_dict_ = {}
    state_dict_[self._language_model_key] \
        = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
    # Save word_embeddings.
    if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
        state_dict_[self._word_embeddings_for_head_key] \
            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
    return state_dict_
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(
                    iteration, args.eval_iters))

            for _ in range(get_num_microbatches()):
                if not mpu.is_pipeline_first_stage():
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)
                else:
                    input_tensor = None

                # Forward evaluation.
                output_tensor = forward_step_func(data_iterator, model,
                                                  input_tensor)

                if mpu.is_pipeline_last_stage():
                    _, loss_dict = output_tensor
                    # Reduce across processes.
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + \
                            loss_dict[key]
                else:
                    communicate(tensor_send_next=output_tensor,
                                tensor_send_prev=None,
                                recv_forward=False,
                                recv_backward=False)

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Move model back to the train mode.
    model.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict
def forward(self, hidden_states, attention_mask, layer_past=None,
            get_key_value=False):

    # Checks.
    if layer_past is not None:
        assert get_key_value, \
            'for not None values in layer_past, ' \
            'expected get_key_value to be set'
    if get_key_value:
        assert not self.checkpoint_activations, \
            'get_key_value does not work with ' \
            'activation checkpointing'

    if mpu.is_pipeline_first_stage():
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        # If the input flag for fp32 residual connection is set, convert to
        # float; otherwise, leave the dtype as is.
        if self.fp32_residual_connection:
            hidden_states = hidden_states.transpose(0, 1).contiguous().float()
        else:
            hidden_states = hidden_states.transpose(0, 1).contiguous()

    if self.checkpoint_activations:
        hidden_states = self._checkpointed_forward(hidden_states,
                                                   attention_mask)
    else:
        if get_key_value:
            presents = []
        for index in range(self.num_layers):
            layer = self._get_layer(index)
            past = None
            if layer_past is not None:
                past = layer_past[index]
            hidden_states = layer(hidden_states,
                                  attention_mask,
                                  layer_past=past,
                                  get_key_value=get_key_value)
            if get_key_value:
                hidden_states, present = hidden_states
                presents.append(present)

    # Final layer norm.
    if mpu.is_pipeline_last_stage():
        # Reverting data format change [s b h] --> [b s h].
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        output = self.final_layernorm(hidden_states)
    else:
        output = hidden_states
    if get_key_value:
        output = [output, presents]

    return output
def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Receive from previous pipeline stage and update the
    input buffer inplace."""
    if not mpu.is_pipeline_first_stage():
        assert recv_buffer is not None
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_buffer,
            mpu.get_pipeline_model_parallel_prev_rank())
        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()
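# Hedged counterpart sketch (assumed, not shown in this file): the matching
# send to the next pipeline stage, built on the same batched P2P API so that
# sends and receives could later be fused into one batch_isend_irecv() call.
def send_to_next_pipeline_rank(tensor=None):
    """Send output to the next pipeline stage."""
    if not mpu.is_pipeline_last_stage():
        assert tensor is not None
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor,
            mpu.get_pipeline_model_parallel_next_rank())
        reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # Guard against the same batch_isend_irecv() race noted above.
        torch.cuda.synchronize()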
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                                  lm_labels=lm_labels)
        else:
            output_tensor = model(tokens, padding_mask, tokentype_ids=types)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, padding_mask,
                              lm_labels=lm_labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, padding_mask)

    if mpu.is_pipeline_last_stage():
        lm_loss_, sop_logits = output_tensor

        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()

        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss + sop_loss

        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])

        return loss, {'lm loss': averaged_losses[0],
                      'sop loss': averaged_losses[1]}
    return output_tensor
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                       init_method=None, scaled_init_method=None):
    """Build language model and return along with the key to save."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Select the language model class based on this stage's position in
    # the pipeline.
    cls_args = [attention_mask_func, init_method, scaled_init_method]
    kwargs = {}
    cls = None
    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModel
        kwargs['num_tokentypes'] = num_tokentypes
        kwargs['add_pooler'] = add_pooler
    elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModelFirstStage
        kwargs['num_tokentypes'] = num_tokentypes
    elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModelLastStage
        kwargs['add_pooler'] = add_pooler
    else:
        cls = TransformerLanguageModelIntermediateStage

    # Language model.
    language_model = cls(*cls_args, **kwargs)

    # Key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
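# Hypothetical usage sketch (the callback name and argument values are
# assumptions, not from the source): building a BERT-style language model
# with token types and a pooler; the function above picks the right
# stage-specific class automatically.
language_model, language_model_key = get_language_model(
    attention_mask_func=bert_attention_mask_func,
    num_tokentypes=2,
    add_pooler=True)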
def send_backward_recv_forward(input_tensor_grad, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(tensor_send_next=None,
                                       tensor_send_prev=input_tensor_grad,
                                       recv_prev=True,
                                       recv_next=False)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor
def forward_and_backward_steps_with_communication(forward_step_func,
                                                  data_iterator,
                                                  model,
                                                  optimizer,
                                                  input_tensor,
                                                  last_microbatch,
                                                  input_tensors,
                                                  output_tensors,
                                                  losses_reduced,
                                                  timers):
    args = get_args()

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        output_tensor_grad = None
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send-backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send-backward-recv').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)

    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor,
                      output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send-forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=(not last_microbatch),
            recv_backward=False)
        timers('backward-send-forward-recv').stop()
    else:
        input_tensor = None

    return input_tensor