def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers):
    """Run forward and backward passes without inter-stage communication."""
    args = get_args()

    losses_reduced = []
    for i in range(get_num_microbatches()):
        # Forward pass.
        timers('forward-compute').start()
        loss, loss_reduced = forward_step_func(data_iterator, model,
                                               input_tensor=None)
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
        timers('forward-compute').stop()

        # Backward pass.
        timers('backward-compute').start()
        output_tensor_grad = None
        backward_step(optimizer, model, input_tensor=None,
                      output_tensor=output_tensor,
                      output_tensor_grad=output_tensor_grad)
        timers('backward-compute').stop()

    return losses_reduced

def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            for _ in range(get_num_microbatches()):
                if not mpu.is_pipeline_first_stage():
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)
                else:
                    input_tensor = None

                # Forward evaluation.
                output_tensor = forward_step_func(data_iterator, model,
                                                  input_tensor)

                if mpu.is_pipeline_last_stage():
                    _, loss_dict = output_tensor
                    # Reduce across processes.
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + \
                            loss_dict[key]
                else:
                    communicate(
                        tensor_send_next=output_tensor,
                        tensor_send_prev=None,
                        recv_forward=False,
                        recv_backward=False)

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Move model back to the train mode.
    model.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(
                    iteration, args.eval_iters))

            forward_backward_func = get_forward_backward_func()
            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Empty unused memory.
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + \
                            loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    """Run a single forward step: receive the input activation from the
    previous pipeline stage (unless first stage), run the forward pass, and
    either accumulate the loss (last stage) or send the output downstream.
    Input/output tensors are saved for the matching backward pass."""
    args = get_args()

    if not mpu.is_pipeline_first_stage():
        timers('forward-recv').start()
        input_tensor, _ = communicate(tensor_send_next=None,
                                      tensor_send_prev=None,
                                      recv_forward=True,
                                      recv_backward=False)
        timers('forward-recv').stop()
    else:
        input_tensor = None

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send').start()
        communicate(tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
        timers('forward-send').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)

def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Returns dictionary with losses."""
    assert len(model) == 1
    model = model[0]

    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for i in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator,
                                         model, input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor,
                      output_tensor_grad)

    return losses_reduced

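# Note: forward_backward_no_pipelining() above falls back to a no-op context
# manager named dummy_handler when the model is not wrapped in torchDDP (in
# which case model.no_sync is used instead to delay gradient all-reduce until
# the last microbatch). A minimal sketch of such a helper is below; the exact
# definition in the source tree may differ.
from contextlib import contextmanager

@contextmanager
def dummy_handler():
    # Do nothing on entry and exit; exists only so the same `with` block can
    # be used whether or not gradient synchronization needs to be deferred.
    try:
        yield
    finally:
        pass
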
def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers):
    """Run 1F1B schedule, with communication and warmup + cooldown
    microbatches as needed."""
    args = get_args()

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors, losses_reduced, timers)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            timers('forward-recv').start()
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)
            timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        input_tensor = \
            forward_and_backward_steps_with_communication(
                forward_step_func, data_iterator, model, optimizer,
                input_tensor, last_iteration,
                input_tensors, output_tensors, losses_reduced, timers)

    # Run cooldown backward passes.
    for i in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced

def get_forward_backward_func():
    """Return the forward-backward schedule matching the parallel
    configuration: interleaved pipelining, non-interleaved pipelining, or no
    pipelining."""
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % \
                args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func

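# Illustrative call site (not from the source): all three schedules returned
# by get_forward_backward_func() share the same signature, so train_step()
# and evaluate() can invoke whichever one is selected without branching on
# the parallel configuration.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(forward_step_func,
                                       data_iterator,
                                       model,           # list of model chunks
                                       optimizer,
                                       timers,
                                       forward_only=False)
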
def forward_and_backward_steps_with_communication(forward_step_func,
                                                  data_iterator, model,
                                                  optimizer, input_tensor,
                                                  last_microbatch,
                                                  input_tensors,
                                                  output_tensors,
                                                  losses_reduced, timers):
    """Run one forward step and one backward step (1F1B steady state),
    exchanging activations and gradients with the neighbouring pipeline
    stages."""
    args = get_args()

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        output_tensor_grad = None
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send-backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send-backward-recv').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)

    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor,
                      output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send-forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=(not last_microbatch),
            recv_backward=False)
        timers('backward-send-forward-recv').stop()
    else:
        input_tensor = None

    return input_tensor

def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func)

    unwrapped_model = model
    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
        unwrapped_model = unwrapped_model.module
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load checkpoint').stop()
        timers.log(['load checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    unwrapped_model = model
    while hasattr(unwrapped_model, 'module'):
        unwrapped_model = unwrapped_model.module

    if args.iteration == 0 and hasattr(unwrapped_model,
                                       'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        unwrapped_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler

def forward_step(forward_step_func, data_iterator, model, input_tensor,
                 losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    timers = get_timers()

    timers('forward-compute').start()
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    return output_tensor

def forward_step(forward_step_func, data_iterator, model, input_tensor,
                 losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    args = get_args()
    timers = get_timers()

    timers('forward-compute').start()
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    if mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]

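# forward_step() above relies on an unwrap_model() helper to strip DDP and
# fp16 wrappers before calling set_input_tensor(). The sketch below shows the
# idea with a hypothetical name (unwrap_model_sketch); the real utility in the
# repo may additionally handle lists of model chunks.
def unwrap_model_sketch(model, module_instances):
    """Peel off wrapper modules (e.g. torchDDP, LocalDDP, Float16Module)
    until the underlying model is reached."""
    while isinstance(model, module_instances):
        model = model.module
    return model
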
def forward_backward_pipelining_without_interleaving(forward_step_func,
                                                     data_iterator, model,
                                                     optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), \
                output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced

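# Worked example (illustrative, not from the source): with a 4-stage pipeline
# and 8 microbatches, the warmup / steady-state split computed above gives
# each rank fewer warmup microbatches the closer it is to the last stage, so
# the last stage starts 1F1B immediately while earlier stages pre-fill the
# pipeline.
pipeline_world_size = 4
num_microbatches = 8
for rank in range(pipeline_world_size):
    warmup = min(pipeline_world_size - rank - 1, num_microbatches)
    remaining = num_microbatches - warmup
    print('rank {}: warmup={}, steady-state={}'.format(rank, warmup,
                                                       remaining))
# rank 0: warmup=3, steady-state=5
# rank 1: warmup=2, steady-state=6
# rank 2: warmup=1, steady-state=7
# rank 3: warmup=0, steady-state=8
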
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(
                    train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(
                iteration))
            sys.exit()

    return iteration

def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-send-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
    add_to_logging('optimizer-unscale-and-check-inf')
    add_to_logging('optimizer-clip-main-grad')
    add_to_logging('optimizer-copy-main-to-model-params')
    add_to_logging('optimizer')
    add_to_logging('batch-generator')

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()
    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and is_last_rank():
        writer.add_scalar('learning-rate', learning_rate, iteration)
        writer.add_scalar('learning-rate vs samples', learning_rate,
                          args.consumed_train_samples)
        writer.add_scalar('batch-size', batch_size, iteration)
        writer.add_scalar('batch-size vs samples', batch_size,
                          args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        writer.add_scalar('loss-scale', loss_scale, iteration)
        writer.add_scalar('loss-scale vs samples', loss_scale,
                          args.consumed_train_samples)
        timers.write(timers_to_log, writer, iteration,
                     normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration-time',
                              elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                       args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag

def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
            mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(
                losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad

def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        else:
            # We do not support the interleaved schedule for T5 yet.
            unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad,
                                         group=mpu.get_embedding_group())

    # All-reduce position_embeddings grad across first (encoder) and
    # split (decoder) stages to ensure that position embeddings parameters
    # stay in sync. This should only run for T5 models with pipeline
    # parallelism.
    if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
        unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
        assert args.DDP_impl == 'local', \
            'T5 model is only supported with local DDP mode'
        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
        torch.distributed.all_reduce(
            grad, group=mpu.get_position_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(
                losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad

def forward_backward_pipelining_with_interleaving(forward_step_func,
                                                  data_iterator, model,
                                                  optimizer, timers,
                                                  forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    args = get_args()
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration
        number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size *
                                                  num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    timers=timers)
            output_tensor_grads[num_model_chunks - 1].append(
                output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        deallocate_output_tensor(output_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k,
                                                     forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by
            # (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by
            # (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor, input_tensor_grad,
                recv_prev=recv_prev, recv_next=recv_next,
                tensor_shape=tensor_shape, timers=timers)
        deallocate_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape, timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1,
                                                              forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    timers=timers))

    return losses_reduced

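# Worked example (illustrative, not from the source): how get_model_chunk_id()
# maps microbatch indices to model chunks when pipeline_parallel_size=2 and
# num_model_chunks=2. Forward microbatches visit chunks 0,0,1,1,0,0,1,1,...;
# backward microbatches traverse the chunks in reverse order.
pipeline_parallel_size = 2
num_model_chunks = 2

def model_chunk_id(microbatch_id, forward):
    in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    chunk = in_group // pipeline_parallel_size
    return chunk if forward else num_model_chunks - chunk - 1

print([model_chunk_id(k, True) for k in range(8)])   # [0, 0, 1, 1, 0, 0, 1, 1]
print([model_chunk_id(k, False) for k in range(8)])  # [1, 1, 0, 0, 1, 1, 0, 0]
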
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        losses_reduced = forward_backward_pipelining(forward_step_func,
                                                     data_iterator, model,
                                                     optimizer, timers)
    else:
        losses_reduced = forward_backward_no_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        unwrapped_model = model
        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
            unwrapped_model = unwrapped_model.module
        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            torch.distributed.all_reduce(word_embeddings_weight.grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage():
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(
                losses_reduced_for_key)
        return loss_reduced, skipped_iter
    return {}, skipped_iter

def forward_backward_pipelining_without_interleaving(forward_step_func,
                                                     data_iterator, model,
                                                     optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    args = get_args()
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches,
                                  num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    model_type = unwrapped_model.model_type
    rank = mpu.get_pipeline_model_parallel_rank()
    recv_tensor_shapes = get_tensor_shapes(rank - 1, model_type)
    send_tensor_shapes = get_tensor_shapes(rank, model_type)

    # Input, output tensors only need to be saved when doing backward passes.
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, timers=timers)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

        else:
            output_tensor_grad = \
                send_forward_recv_backward(output_tensor,
                                           send_tensor_shapes,
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

            # Pop input_tensor and output_tensor from the start of the list
            # for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes,
                              timers=timers)
            else:
                input_tensor = \
                    send_backward_recv_forward(
                        input_tensor_grad, recv_tensor_shapes, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes,
                                               timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            send_backward(input_tensor_grad, recv_tensor_shapes,
                          timers=timers)

    return losses_reduced

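# The schedule above asks a get_tensor_shapes() helper for the list of tensor
# shapes exchanged at each pipeline boundary; that is also why recv/send work
# on lists and why deallocate_output_tensor(output_tensor[0]) indexes into a
# list. A rough sketch of the idea is below, using a hypothetical name and
# signature (the real helper derives everything from args and model_type and
# may differ in detail): decoder-side stages of an encoder-decoder model carry
# the encoder hidden state in addition to their own activations, so they
# exchange two tensors per microbatch instead of one.
def get_tensor_shapes_sketch(args, is_encoder_and_decoder, after_decoder_split):
    """Return one (seq, batch, hidden) shape per tensor crossing a stage
    boundary."""
    hidden = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if is_encoder_and_decoder and after_decoder_split:
        decoder_hidden = (args.decoder_seq_length, args.micro_batch_size,
                          args.hidden_size)
        return [decoder_hidden, hidden]
    return [hidden]
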
def _train(model, optimizer, lr_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches() == 1, \
        "finetuning with gradient accumulation doesn't currently work"

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval-time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer,
                             lr_scheduler)
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(
                losses_dict, losses_dict_sum,
                optimizer.param_groups[0]['lr'],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag, skipped_iter,
                grad_norm, params_norm, num_zeros_in_grad)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, lr_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)
                saved_checkpoint = True

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, False)

            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer,
                                    lr_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(
                    iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
