def __init__(self, attention_mask_func,
             init_method, output_layer_init_method):
    super(ParallelTransformer, self).__init__()
    args = get_args()

    self.fp32_residual_connection = args.fp32_residual_connection

    # Store activation checkpointing flag.
    self.checkpoint_activations = args.checkpoint_activations
    self.checkpoint_num_layers = args.checkpoint_num_layers

    # Number of layers.
    assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
        'num_layers must be divisible by pipeline_model_parallel_size'
    self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

    # Transformer layers.
    def build_layer(layer_number):
        return ParallelTransformerLayer(
            attention_mask_func, init_method,
            output_layer_init_method, layer_number)

    offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
    self.layers = torch.nn.ModuleList(
        [build_layer(i + 1 + offset) for i in range(self.num_layers)])

    if mpu.is_pipeline_last_stage():
        # Final layer norm before output.
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.final_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)
def model_provider():
    """Build the model."""

    if eval_metric == 'loss':
        parallel_output = True
    elif eval_metric == 'accuracy':
        parallel_output = False
    else:
        raise NotImplementedError('output type for {} evaluation metric '
                                  'is not supported.'.format(eval_metric))

    print_rank_0('building GPT2 model ...')
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = GPT2ModelFirstStage(num_tokentypes=0)
        elif mpu.is_pipeline_last_stage():
            model = GPT2ModelLastStage(parallel_output=parallel_output,
                                       num_tokentypes=0)
        else:
            model = GPT2ModelIntermediateStage(num_tokentypes=0)
    else:
        model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)

    return model
def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False,
                             pre_process=True,
                             post_process=True):
    """Build the model."""

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    # simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes
    model = BiEncoderModel(
        num_tokentypes=2,
        parallel_output=False,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        biencoder_shared_query_context_model=\
            biencoder_shared_query_context_model,
        pre_process=pre_process,
        post_process=post_process)

    return model
def backward_step(optimizer, input_tensor, output_tensor,
                  output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.
    args = get_args()

    timers = get_timers()
    timers('backward-compute').start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_input_tensor_grad = True
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass. Keep output_tensor as a list so the indexing below
    # still works after the loss has been scaled on the last stage.
    if output_tensor_grad[0] is None:
        output_tensor[0] = optimizer.scale_loss(output_tensor[0])
    custom_backward(output_tensor[0], output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        for x in input_tensor:
            if x is None:
                input_tensor_grad.append(None)
            else:
                input_tensor_grad.append(x.grad)

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
            mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    timers('backward-compute').stop()

    return input_tensor_grad
def forward_backward_pipelining(forward_step_func, data_iterator,
                                model, optimizer, timers):
    """Run 1F1B schedule, with communication and warmup + cooldown
    microbatches as needed."""
    args = get_args()

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors, losses_reduced, timers)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            timers('forward-recv').start()
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)
            timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        input_tensor = \
            forward_and_backward_steps_with_communication(
                forward_step_func, data_iterator, model, optimizer,
                input_tensor, last_iteration,
                input_tensors, output_tensors, losses_reduced, timers)

    # Run cooldown backward passes.
    for i in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced
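# Illustrative sketch only (not part of Megatron): how the warmup count used
# by the 1F1B schedules above falls out of the pipeline size and rank. Earlier
# stages run more warmup forward passes so that the pipeline fills before the
# steady-state 1F1B phase begins; the count is capped by the number of
# microbatches. Names and default values below are hypothetical.
def example_warmup_counts(pipeline_world_size=4, num_microbatches=8):
    counts = []
    for rank in range(pipeline_world_size):
        num_warmup = min(pipeline_world_size - rank - 1, num_microbatches)
        num_remaining = num_microbatches - num_warmup
        counts.append((rank, num_warmup, num_remaining))
    return counts

# example_warmup_counts() ->
#   [(0, 3, 5), (1, 2, 6), (2, 1, 7), (3, 0, 8)]
# i.e. rank 0 runs 3 warmup forwards, rank 3 (last stage) starts 1F1B at once.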
def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
def model_provider():
    """Build the model."""
    print_rank_0('building multichoice model for RACE ...')

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = MultipleChoiceFirstStage(num_tokentypes=2)
        elif mpu.is_pipeline_last_stage():
            model = MultipleChoiceLastStage(num_tokentypes=2)
        else:
            model = MultipleChoiceIntermediateStage(num_tokentypes=2)
    else:
        model = MultipleChoice(num_tokentypes=2)

    return model
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            if mpu.get_pipeline_model_parallel_world_size() > 1:
                if args.virtual_pipeline_model_parallel_size is not None:
                    forward_backward_func = forward_backward_pipelining_with_interleaving
                else:
                    forward_backward_func = forward_backward_pipelining_without_interleaving
            else:
                forward_backward_func = forward_backward_no_pipelining

            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict
def model_provider():
    """Build the model."""

    print_rank_0('building GPT model ...')
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = GPTModelFirstStage(num_tokentypes=0)
        elif mpu.is_pipeline_last_stage():
            model = GPTModelLastStage(num_tokentypes=0, parallel_output=True)
        else:
            model = GPTModelIntermediateStage(num_tokentypes=0)
    else:
        model = GPTModel(num_tokentypes=0, parallel_output=True)

    return model
def get_checkpoint_name(checkpoints_path, iteration,
                        release=False):
    """A unified checkpoint name."""
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}'.format(
                                mpu.get_tensor_model_parallel_rank()),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank()),
                        'model_optim_rng.pt')
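# Illustrative only: what the format strings in get_checkpoint_name() expand
# to, assuming iteration 5000, tensor rank 3, and pipeline rank 12. The
# checkpoint root directory is whatever checkpoints_path was passed in.
print('iter_{:07d}'.format(5000))             # iter_0005000
print('mp_rank_{:02d}'.format(3))             # mp_rank_03     (no pipeline parallelism)
print('mp_rank_{:02d}_{:03d}'.format(3, 12))  # mp_rank_03_012 (with pipeline parallelism)
# e.g. <checkpoints_path>/iter_0005000/mp_rank_03_012/model_optim_rng.pt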
def general_ict_model_provider(only_query_model=False, only_block_model=False):
    """Build the model."""
    args = get_args()
    assert args.ict_head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building ICTBertModel...')

    # simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes
    model = ICTBertModel(
        ict_head_size=args.ict_head_size,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model)

    return model
def model_provider():
    """Build the model."""
    args = get_args()

    print_rank_0('building classification model for {} ...'.format(
        args.task))

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = ClassificationFirstStage(
                num_classes=num_classes, num_tokentypes=2)
        elif mpu.is_pipeline_last_stage():
            model = ClassificationLastStage(
                num_classes=num_classes, num_tokentypes=2)
        else:
            model = ClassificationIntermediateStage(
                num_classes=num_classes, num_tokentypes=2)
    else:
        model = Classification(num_classes=num_classes, num_tokentypes=2)

    return model
def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func)
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    if args.iteration == 0 and len(unwrapped_model) == 1 \
            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler
def get_fmoe_checkpoint_name(checkpoints_path, iteration,
                             release=False, data_parallel_rank=-1):
    """A unified checkpoint name, allowing specifying a data parallel rank"""
    from megatron import mpu
    from megatron.checkpointing import get_checkpoint_name

    if data_parallel_rank == -1:
        data_parallel_rank = mpu.get_data_parallel_rank()
    if data_parallel_rank == 0:
        return get_checkpoint_name(checkpoints_path, iteration, release)

    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
            checkpoints_path,
            directory,
            "mp_rank_{:02d}_dp_rank_{:04d}".format(
                mpu.get_tensor_model_parallel_rank(), data_parallel_rank),
            "model_optim_rng.pt",
        )
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}_{:03d}_dp_rank_{:04d}".format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            data_parallel_rank,
        ),
        "model_optim_rng.pt",
    )
def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.virtual_pipeline_model_parallel_size is not None:
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(pre_process=pre_process,
                                             post_process=post_process)
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        model = model_provider_func(pre_process=pre_process,
                                    post_process=post_process)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum([sum([p.nelement() for p in model_module.parameters()])
                       for model_module in model])), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = [torchDDP(model_module, device_ids=[i], output_device=i,
                          process_group=mpu.get_data_parallel_group())
                 for model_module in model]
        return model

    if args.DDP_impl == 'local':
        model = [LocalDDP(model_module,
                          args.accumulate_allreduce_grads_in_fp32,
                          args.use_contiguous_buffers_in_ddp)
                 for model_module in model]
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
def forward_backward_pipelining_without_interleaving(forward_step_func,
                                                     data_iterator, model,
                                                     optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    args = get_args()
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches,
                                  num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    model_type = unwrapped_model.model_type
    rank = mpu.get_pipeline_model_parallel_rank()
    recv_tensor_shapes = get_tensor_shapes(rank - 1, model_type)
    send_tensor_shapes = get_tensor_shapes(rank, model_type)

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, timers=timers)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

        else:
            output_tensor_grad = \
                send_forward_recv_backward(output_tensor, send_tensor_shapes,
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

            # Pop input_tensor and output_tensor from the start of the list
            # for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes,
                              timers=timers)
            else:
                input_tensor = \
                    send_backward_recv_forward(input_tensor_grad,
                                               recv_tensor_shapes,
                                               timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes,
                                               timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            send_backward(input_tensor_grad, recv_tensor_shapes,
                          timers=timers)

    return losses_reduced
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
        mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        unwrapped_model = unwrap_model(unwrapped_model,
                                       (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(
                losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad
def forward_backward_pipelining_with_interleaving(forward_step_func,
                                                  data_iterator, model,
                                                  optimizer, timers,
                                                  forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    args = get_args()
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size *
                                                  num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape, timers=timers)
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape, timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        deallocate_output_tensor(output_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor, input_tensor_grad,
                recv_prev=recv_prev, recv_next=recv_next,
                tensor_shape=tensor_shape, timers=timers)
        deallocate_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape, timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1,
                                                              forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next,
                    tensor_shape=tensor_shape, timers=timers))

    return losses_reduced
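# Illustrative sketch only (not part of Megatron): a trace of the chunk
# assignment implemented by get_model_chunk_id() above, assuming
# pipeline_parallel_size=2 and num_model_chunks=2. Forward microbatches cycle
# through chunks in groups of pipeline_parallel_size; the backward order
# visits the chunks in reverse. Names and defaults below are hypothetical.
def example_chunk_schedule(pipeline_parallel_size=2, num_model_chunks=2,
                           num_microbatches=8):
    def chunk_id(microbatch_id, forward):
        mb_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        cid = mb_in_group // pipeline_parallel_size
        return cid if forward else num_model_chunks - cid - 1

    forward_order = [chunk_id(m, True) for m in range(num_microbatches)]
    backward_order = [chunk_id(m, False) for m in range(num_microbatches)]
    return forward_order, backward_order

# example_chunk_schedule() -> ([0, 0, 1, 1, 0, 0, 1, 1],
#                              [1, 1, 0, 0, 1, 1, 0, 0])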
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder,
              wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.virtual_pipeline_model_parallel_size is not None:
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(pre_process=pre_process,
                                             post_process=post_process)
            this_model.model_type = model_type
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                assert args.pipeline_model_parallel_split_rank is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = (rank == (split_rank - 1)) or (
                    rank == (world_size - 1))
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(pre_process=pre_process,
                                        post_process=post_process,
                                        add_encoder=add_encoder,
                                        add_decoder=add_decoder)
        else:
            model = model_provider_func(pre_process=pre_process,
                                        post_process=post_process)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum([sum([p.nelement() for p in model_module.parameters()])
                       for model_module in model])), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if wrap_with_ddp:
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = [torchDDP(model_module, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
                     for model_module in model]

        elif args.DDP_impl == 'local':
            model = [LocalDDP(model_module,
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]
            # Broadcast params from data parallel src rank to other data
            # parallel ranks.
            if args.data_parallel_random_init:
                for model_module in model:
                    model_module.broadcast_params()

        else:
            raise NotImplementedError('Unknown DDP implementation specified: '
                                      '{}. Exiting.'.format(args.DDP_impl))

    return model
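# Illustrative sketch only (not part of Megatron): which pipeline stages end
# up with pre_process / post_process for an encoder-decoder model, following
# the rank arithmetic in get_model() above. Assumes 4 pipeline stages and
# pipeline_model_parallel_split_rank=2 (first two stages hold the encoder,
# last two the decoder). Names and defaults are hypothetical.
def example_encoder_decoder_flags(world_size=4, split_rank=2):
    flags = []
    for rank in range(world_size):
        pre_process = rank == 0 or rank == split_rank
        post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
        flags.append((rank, pre_process, post_process))
    return flags

# example_encoder_decoder_flags() ->
#   [(0, True, False), (1, False, True), (2, True, False), (3, False, True)]
# i.e. ranks 0 and 2 own the encoder/decoder embeddings, ranks 1 and 3 own
# the encoder output and the decoder output head, respectively.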
def forward_backward_pipelining_without_interleaving(forward_step_func,
                                                     data_iterator, model,
                                                     optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches,
                                  num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(
                0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        else:  # We do not support the interleaved schedule for T5 yet.
            unwrapped_model = model[0]
        unwrapped_model = unwrap_model(unwrapped_model,
                                       (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    # All-reduce position_embeddings grad across first (encoder) and split (decoder)
    # stages to ensure that position embeddings parameters stay in sync.
    # This should only run for T5 models with pipeline parallelism
    if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
        unwrapped_model = model[0]
        unwrapped_model = unwrap_model(unwrapped_model,
                                       (torchDDP, LocalDDP, Float16Module))
        assert args.DDP_impl == 'local', \
            'T5 model is only supported with local DDP mode'
        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
        torch.distributed.all_reduce(grad,
                                     group=mpu.get_position_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(
                losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad
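# Illustrative only: the lr_scheduler increment computed in train_step() is
# the number of samples consumed in one iteration (the global batch size).
# Assuming, for example, 8 microbatches of micro-batch size 4 and
# data-parallel size 16:
print(8 * 4 * 16)  # 512 samples consumed per training step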
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        losses_reduced = forward_backward_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)
    else:
        losses_reduced = forward_backward_no_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        unwrapped_model = model
        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
            unwrapped_model = unwrapped_model.module

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            torch.distributed.all_reduce(word_embeddings_weight.grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage():
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(
                losses_reduced_for_key)
        return loss_reduced, skipped_iter
    return {}, skipped_iter
def __init__(self, init_method, output_layer_init_method,
             layer_type=LayerType.encoder,
             self_attn_mask_type=AttnMaskType.padding,
             pre_process=True, post_process=True):
    super(ParallelTransformer, self).__init__()
    args = get_args()

    self.bf16 = args.bf16
    self.fp32_residual_connection = args.fp32_residual_connection
    self.pre_process = pre_process
    self.post_process = post_process
    self.input_tensor = None

    # Store activation checkpointing flag.
    self.checkpoint_activations = args.checkpoint_activations
    self.checkpoint_num_layers = args.checkpoint_num_layers

    # Number of layers.
    assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
        'num_layers must be divisible by pipeline_model_parallel_size'
    self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

    # Transformer layers.
    def build_layer(layer_number):
        return ParallelTransformerLayer(
            init_method,
            output_layer_init_method,
            layer_number,
            layer_type=layer_type,
            self_attn_mask_type=self_attn_mask_type)

    if args.virtual_pipeline_model_parallel_size is not None:
        assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
            'num_layers_per_stage must be divisible by ' \
            'virtual_pipeline_model_parallel_size'
        # Number of layers in each model chunk is the number of layers in the stage,
        # divided by the number of model chunks in a stage.
        self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
        # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
        # layers to stages like (each list is a model chunk):
        # Stage 0: [0]  [2]  [4]  [6]
        # Stage 1: [1]  [3]  [5]  [7]
        # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
        # layers to stages like (each list is a model chunk):
        # Stage 0: [0, 1]  [4, 5]
        # Stage 1: [2, 3]  [6, 7]
        offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
            args.num_layers // args.virtual_pipeline_model_parallel_size) + \
            (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
    else:
        # Each stage gets a contiguous set of layers.
        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

    self.layers = torch.nn.ModuleList(
        [build_layer(i + 1 + offset) for i in range(self.num_layers)])

    if self.post_process:
        # Final layer norm before output.
        self.final_layernorm = LayerNorm(args.hidden_size,
                                         eps=args.layernorm_epsilon)
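# Illustrative sketch only (not part of Megatron): the layer-to-chunk
# assignment produced by the virtual-pipeline offset formula above, for the
# two cases described in its comment (8 layers, 2 pipeline stages, 4 or 2
# model chunks). Offsets are 0-based layer indices; names and defaults below
# are hypothetical.
def example_layer_offsets(num_layers=8, pipeline_size=2, virtual_size=4):
    num_layers_per_chunk = num_layers // pipeline_size // virtual_size
    assignment = {}
    for stage in range(pipeline_size):
        chunks = []
        for virtual_rank in range(virtual_size):
            offset = virtual_rank * (num_layers // virtual_size) + \
                stage * num_layers_per_chunk
            chunks.append(list(range(offset, offset + num_layers_per_chunk)))
        assignment[stage] = chunks
    return assignment

# example_layer_offsets()               -> {0: [[0], [2], [4], [6]], 1: [[1], [3], [5], [7]]}
# example_layer_offsets(virtual_size=2) -> {0: [[0, 1], [4, 5]], 1: [[2, 3], [6, 7]]}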