def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from the last pipeline stage to the first.

    Args:
        size: Shape used to allocate the receive buffer on the first stage.
        dtype: Dtype used to allocate the receive buffer on the first stage.
        tensor: Tensor to send. It is only read on the last stage; it is
            checked there with ``_is_cuda_contiguous``.

    Returns:
        ``tensor`` unchanged when the first and last stage are the same
        rank, the sent tensor on the last stage, the newly allocated and
        filled buffer on the first stage, and ``None`` on any other stage.
    """
    on_last = mpu.is_pipeline_last_stage()
    on_first = mpu.is_pipeline_first_stage()

    # A single stage is both first and last: no pipeline parallelism,
    # hence nothing to communicate.
    if on_first and on_last:
        return tensor

    # Intermediate stages take no part in the exchange.
    if not (on_last or on_first):
        return None

    if on_last:
        _is_cuda_contiguous(tensor)
        payload = tensor
    else:
        # Receiving side: allocate a buffer on the current CUDA device.
        payload = torch.empty(size,
                              dtype=dtype,
                              device=torch.cuda.current_device())

    # The last stage is the source; the embedding group is the communicator.
    torch.distributed.broadcast(payload,
                                mpu.get_pipeline_model_parallel_last_rank(),
                                mpu.get_embedding_group())
    return payload
def initialize_word_embeddings(self, init_method_normal):
    """Set up the word embeddings shared between the input layer and head.

    The embedding parameters are shared between the word-embedding layer
    and the head at the end of the model. With more than one pipeline
    stage those live on different workers, so:

      1. A second copy of ``word_embeddings`` is created on the last stage
         with all parameters set to 0.0.
      2. An all-reduce between the first and last stage makes both copies
         start from the same parameter values.
      3. In the training loop, an all-reduce over the grads of the two
         copies keeps every applied weight update identical on both stages.

    Args:
        init_method_normal: Callable invoked with ``args.init_method_std``;
            its result is passed as ``init_method`` to the new embedding.

    Raises:
        Exception: If ``self.share_word_embeddings`` is false.
    """
    args = get_args()
    if not self.share_word_embeddings:
        raise Exception('initialize_word_embeddings() was called but '
                        'share_word_embeddings is false')

    if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
        # First and last stages differ: build the head's own copy, zero it,
        # and let the all-reduce below bring in the first stage's weights.
        self._word_embeddings_for_head_key = 'word_embeddings_for_head'
        self.word_embeddings = mpu.VocabParallelEmbedding(
            args.padded_vocab_size,
            args.hidden_size,
            init_method=init_method_normal(args.init_method_std))
        head_weight = self.word_embeddings.weight
        head_weight.data.fill_(0)
        head_weight.shared = True

    # Make sure first and last stages hold the same initial values.
    if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
        torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                     group=mpu.get_embedding_group())
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from the last pipeline stage into the first.

    The input ``tensor`` is updated in place on the first stage; nothing
    is returned.

    Args:
        size: Shape of the temporary receive buffer used on the first
            stage when ``tensor`` is not contiguous.
        dtype: Dtype of that temporary receive buffer.
        tensor: Source tensor on the last stage, destination tensor on the
            first stage. It is checked with ``_is_cuda`` on both.
    """
    on_last = mpu.is_pipeline_last_stage()
    on_first = mpu.is_pipeline_first_stage()

    # One stage that is both first and last means no pipeline parallelism,
    # so there is nothing to copy.
    if on_first and on_last:
        return

    # Only the two boundary stages are involved.
    if not (on_last or on_first):
        return

    _is_cuda(tensor)
    contiguous = tensor.is_contiguous()
    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_embedding_group()

    if contiguous:
        # The tensor itself can serve as the communication buffer.
        buffer = tensor
    elif on_last:
        buffer = tensor.contiguous()
    else:
        buffer = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    # Broadcast from the last stage into the first stage.
    torch.distributed.broadcast(buffer, src, group)

    # A temporary buffer was used on the first stage: write the received
    # values back into the caller's tensor.
    if on_first and not contiguous:
        tensor[...] = buffer
def initialize_word_embeddings(self, init_method_normal):
    """Initialize the head's copy of the word embeddings on the last stage.

    This only does work when pipeline parallelism is in use
    (``args.pipeline_model_parallel_size != 1``). The embedding parameters
    are shared between the word-embedding layer and the head at the end of
    the model, which live on different workers in a pipelined setup, so:

      1. A second copy of ``word_embeddings`` is created on the last stage
         with all parameters set to 0.0.
      2. An all-reduce between the first and last stage makes both copies
         start from the same parameter values.
      3. In the training loop, an all-reduce over the grads of the two
         copies keeps every applied weight update identical on both stages.

    Args:
        init_method_normal: Callable invoked with ``args.init_method_std``;
            its result is passed as ``init_method`` to the new embedding.

    Raises:
        Exception: If ``self.share_word_embeddings`` is false.
    """
    args = get_args()
    if not self.share_word_embeddings:
        raise Exception('initialize_word_embeddings() was called but '
                        'share_word_embeddings is false')

    # Without pipeline parallelism there is nothing to do.
    if args.pipeline_model_parallel_size == 1:
        return

    if mpu.is_pipeline_last_stage():
        assert not mpu.is_pipeline_first_stage()
        self._word_embeddings_for_head_key = 'word_embeddings_for_head'
        # Zero the new weights here; the all-reduce below then brings in
        # the first stage's values.
        self.word_embeddings = mpu.VocabParallelEmbedding(
            args.padded_vocab_size,
            args.hidden_size,
            init_method=init_method_normal(args.init_method_std))
        head_weight = self.word_embeddings.weight
        head_weight.data.fill_(0)
        head_weight.shared = True

    if not torch.distributed.is_initialized():
        print("WARNING! Distributed processes aren't initialized, so "
              "word embeddings in the last layer are not initialized. "
              "If you are just manipulating a model this is fine, but "
              "this needs to be handled manually. If you are training "
              "something is definitely wrong.")
        return

    # Make sure first and last stages hold the same initial values.
    if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
        torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                     group=mpu.get_embedding_group())
def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Generate tokens one position at a time, yielding after every step.

    This is a generator. ``model`` is switched to eval mode and the whole
    loop runs under ``torch.no_grad()``. ``context_tokens`` is written in
    place at each generated position.

    Args:
        model: Model handed to ``forward_step``.
        context_tokens: Token tensor indexed as ``[batch, position]``.
        context_lengths: Per-sample context lengths; generation starts at
            their minimum.
        attention_mask: Passed through to ``forward_step``.
        position_ids: Position ids, sliced like the tokens.
        maxlen: Last position to generate. When ``None`` it defaults to
            ``args.seq_length - 1``, capped at the starting context length
            plus ``args.out_seq_length``.
        type_ids: Optional token-type ids, sliced like the tokens.

    Yields:
        ``(tokens, lengths)`` on the last pipeline stage,
        ``(tokens, None)`` on the first stage, and ``(None, None)`` on any
        other stage.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eod_token = tokenizer.eod

        step = 0
        initial_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        finished = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = min(args.seq_length - 1,
                         initial_context_length + args.out_seq_length)

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= maxlen:
            if args.recompute:
                # Re-run the model over the full sequence at every step.
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
                if step == 0:
                    # First pass: feed everything up to the context length.
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    types2use = (None if type_ids is None
                                 else type_ids[:, :context_length])
                else:
                    # Later passes: feed only the previous position and
                    # pass the returned layer_past back in.
                    previous = context_length - 1
                    tokens2use = tokens[:, previous].view(batch_size, -1)
                    positions2use = position_ids[:, previous].view(
                        batch_size, -1)
                    types2use = (None if type_ids is None
                                 else type_ids[:, previous].view(
                                     batch_size, -1))
                output, layer_past = forward_step(
                    model, tokens2use,
                    positions2use,
                    attention_mask,
                    layer_past=layer_past,
                    get_key_value=True,
                    tokentype_ids=types2use,
                    forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(probs, num_samples=1).view(-1)

                # Samples whose own context has already been consumed.
                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                # Send the chosen tokens over the embedding group.
                torch.distributed.broadcast(
                    new_tokens,
                    mpu.get_pipeline_model_parallel_last_rank(),
                    mpu.get_embedding_group())

                hit_eod = (prev == eod_token).byte() & started.byte()
                newly_finished = (hit_eod & ~finished).bool()
                lengths[newly_finished.view(-1)] = context_length
                finished = finished | hit_eod

                # Send the stop flag over the pipeline group.
                done = torch.all(finished)
                torch.distributed.broadcast(
                    done,
                    mpu.get_pipeline_model_parallel_last_rank(),
                    mpu.get_pipeline_model_parallel_group())
                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    # Receive the tokens broadcast by the last stage.
                    new_tokens = None
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                # Receive the stop flag broadcast by the last stage.
                done = torch.cuda.ByteTensor([0])
                torch.distributed.broadcast(
                    done,
                    mpu.get_pipeline_model_parallel_last_rank(),
                    mpu.get_pipeline_model_parallel_group())

            context_length += 1
            step += 1
            if done:
                break
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training step.

    Args:
        forward_step_func: Passed through to the forward/backward routine.
        data_iterator: Passed through to the forward/backward routine.
        model: The model, possibly wrapped in ``torchDDP``, ``LocalDDP``
            and/or ``FP16Module``.
        optimizer: Optimizer; ``step()`` is expected to return a truthy
            value when the update was applied.
        lr_scheduler: Stepped by the number of samples processed whenever
            the update was applied.

    Returns:
        A ``(loss_dict, skipped_iter)`` tuple. ``loss_dict`` holds the
        per-key loss averaged over microbatches on the last pipeline stage
        and is empty elsewhere; ``skipped_iter`` is 0 when the optimizer
        update was applied and 1 otherwise.
    """
    args = get_args()
    timers = get_timers()

    # Start from zeroed gradients.
    optimizer.zero_grad()

    # Choose the schedule depending on whether pipelining is enabled.
    run_forward_backward = (
        forward_backward_pipelining
        if mpu.get_pipeline_model_parallel_world_size() > 1
        else forward_backward_no_pipelining)
    losses_reduced = run_forward_backward(
        forward_step_func, data_iterator, model, optimizer, timers)

    # All-reduce the parameters when the local DDP implementation is used.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce the word_embeddings grad across first and last stages so
    # the word_embeddings parameters stay in sync. This should only run
    # for models that support pipelined model parallelism (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    on_boundary_stage = (mpu.is_pipeline_first_stage()
                         or mpu.is_pipeline_last_stage())
    if on_boundary_stage and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        # Peel off every wrapper layer to reach the underlying model.
        core_model = model
        while isinstance(core_model, (torchDDP, LocalDDP, FP16Module)):
            core_model = core_model.module

        if core_model.share_word_embeddings:
            shared_weight = core_model.word_embeddings_weight()
            torch.distributed.all_reduce(shared_weight.grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful = optimizer.step()
    timers('optimizer').stop()

    # Advance the learning rate only when the update went through.
    if update_successful:
        samples_this_step = (get_num_microbatches()
                             * args.micro_batch_size
                             * args.data_parallel_size)
        lr_scheduler.step(increment=samples_this_step)
    skipped_iter = 0 if update_successful else 1

    if not mpu.is_pipeline_last_stage():
        return {}, skipped_iter

    # Average each loss across microbatches.
    loss_reduced = {
        key: sum([entry[key] for entry in losses_reduced])
        / len(losses_reduced)
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training step.

    Args:
        forward_step_func: Passed through to the forward/backward function.
        data_iterator: Passed through to the forward/backward function.
        model: Indexable, iterable collection of model chunks.
        optimizer: Optimizer; ``step()`` must return
            ``(update_successful, grad_norm, num_zeros_in_grad)``.
        lr_scheduler: Stepped by the number of samples processed whenever
            the update was applied.

    Returns:
        ``(loss_dict, skipped_iter, grad_norm, num_zeros_in_grad)``.
        ``loss_dict`` holds the per-key loss averaged over microbatches on
        the last pipeline stage and is empty elsewhere; ``skipped_iter``
        is 0 when the optimizer update was applied and 1 otherwise.
    """
    args = get_args()
    timers = get_timers()
    wrapper_types = (torchDDP, LocalDDP, Float16Module)

    # Start from zeroed gradients (and zeroed contiguous grad buffers when
    # local DDP uses them).
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # All-reduce the gradients when the local DDP implementation is used.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for chunk in model:
            chunk.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce the word_embeddings grad across first and last stages so
    # the word_embeddings parameters stay in sync. This should only run
    # for models that support pipelined model parallelism (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            embedding_chunk = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            embedding_chunk = model[-1]
        else:
            # We do not support the interleaved schedule for T5 yet.
            embedding_chunk = model[0]
        embedding_chunk = unwrap_model(embedding_chunk, wrapper_types)

        if embedding_chunk.share_word_embeddings:
            shared_weight = embedding_chunk.word_embeddings_weight()
            # Local DDP accumulates into main_grad instead of grad.
            grad = (shared_weight.main_grad if args.DDP_impl == 'local'
                    else shared_weight.grad)
            torch.distributed.all_reduce(grad,
                                         group=mpu.get_embedding_group())

    # All-reduce the position_embeddings grad across first (encoder) and
    # split (decoder) stages so the position embeddings stay in sync.
    # This should only run for T5 models with pipeline parallelism.
    if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
        position_chunk = unwrap_model(model[0], wrapper_types)
        assert args.DDP_impl == 'local', \
            'T5 model is only supported with local DDP mode'
        position_weight = \
            position_chunk.language_model.embedding.position_embeddings.weight
        grad = position_weight.main_grad
        torch.distributed.all_reduce(
            grad, group=mpu.get_position_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Advance the learning rate only when the update went through.
    if update_successful:
        samples_this_step = (get_num_microbatches()
                             * args.micro_batch_size
                             * args.data_parallel_size)
        lr_scheduler.step(increment=samples_this_step)
    skipped_iter = 0 if update_successful else 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

    # Average each loss across microbatches.
    loss_reduced = {
        key: sum([entry[key] for entry in losses_reduced])
        / len(losses_reduced)
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training step.

    Args:
        forward_step_func: Passed through to the forward/backward function.
        data_iterator: Passed through to the forward/backward function.
        model: Indexable, iterable collection of model chunks.
        optimizer: Optimizer; ``step()`` must return
            ``(update_successful, grad_norm, num_zeros_in_grad)``.
        lr_scheduler: Stepped by the number of samples processed whenever
            the update was applied.

    Returns:
        ``(loss_dict, skipped_iter, grad_norm, num_zeros_in_grad)``.
        ``loss_dict`` holds the per-key loss averaged over microbatches on
        the last pipeline stage and is empty elsewhere; ``skipped_iter``
        is 0 when the optimizer update was applied and 1 otherwise.
    """
    args = get_args()
    timers = get_timers()

    # Zero the gradients: through the contiguous grad buffers when local
    # DDP uses them, through the optimizer otherwise.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce the gradients when the local DDP implementation is used.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for chunk in model:
            chunk.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce the word_embeddings grad across first and last stages so
    # the word_embeddings parameters stay in sync. This should only run
    # for models that support pipelined model parallelism (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    on_boundary_stage = (mpu.is_pipeline_first_stage(ignore_virtual=True)
                         or mpu.is_pipeline_last_stage(ignore_virtual=True))
    if on_boundary_stage and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        # First stage uses the first chunk, last stage the last chunk.
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            embedding_chunk = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            embedding_chunk = model[-1]
        embedding_chunk = unwrap_model(
            embedding_chunk, (torchDDP, LocalDDP, Float16Module))

        if embedding_chunk.share_word_embeddings:
            shared_weight = embedding_chunk.word_embeddings_weight()
            # Local DDP accumulates into main_grad instead of grad.
            grad = (shared_weight.main_grad if args.DDP_impl == 'local'
                    else shared_weight.grad)
            torch.distributed.all_reduce(grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Advance the learning rate only when the update went through.
    if update_successful:
        samples_this_step = (get_num_microbatches()
                             * args.micro_batch_size
                             * args.data_parallel_size)
        lr_scheduler.step(increment=samples_this_step)
    skipped_iter = 0 if update_successful else 1

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

    # Average each loss across microbatches.
    loss_reduced = {
        key: sum([entry[key] for entry in losses_reduced])
        / len(losses_reduced)
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad