def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval time').start()
    report_memory_flag = True
    while iteration < args.train_iters:
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        iteration += 1

        # Logging.
        loss_scale = None
        if args.fp16:
            loss_scale = optimizer.cur_scale if args.deepspeed else optimizer.loss_scale
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Checkpointing
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
                         'iteration {}'.format(rank, time_str, iteration))
            sys.exit()

    return iteration
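# The loop above gates checkpointing, evaluation, and program exit on
# "every N iterations" checks. A minimal sketch of that interval logic,
# factored into a helper for clarity; the helper name is an illustrative
# assumption, not part of the Megatron codebase.
def is_interval_step(iteration, interval):
    """True when `interval` is set (non-zero) and `iteration` falls on it."""
    return bool(interval) and iteration % interval == 0

assert is_interval_step(500, 500)
assert not is_interval_step(501, 500)
assert not is_interval_step(500, 0)   # a zero/None interval disables the check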
def save(self, context: DeepSpeedTrialContext, path: pathlib.Path) -> None:
    self.neox_args.save = str(path)
    save_checkpoint(
        neox_args=self.neox_args,
        iteration=self.neox_args.iteration,
        model=self.model,
        optimizer=self.optimizer,
        lr_scheduler=self.lr_scheduler,
    )
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    timers = get_timers()
    # Extra barrier is added to make sure all ranks report the max time.
    torch.distributed.barrier()
    timers('save checkpoint').start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers('save checkpoint').stop()
    timers.log(['save checkpoint'])
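# A minimal sketch of the barrier-then-time pattern used by
# save_checkpoint_and_time above: the barrier before starting the timer makes
# every rank begin timing together, and the barrier before stopping makes the
# reported duration reflect the slowest rank. The `timed_collective` helper
# and its use of time.monotonic() are illustrative assumptions, not the
# Megatron Timers API.
import time
import torch

def timed_collective(fn, *args, **kwargs):
    """Run fn on every rank and return (result, wall time of the slowest rank)."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()          # start together
    start = time.monotonic()
    result = fn(*args, **kwargs)
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()          # wait for the slowest rank
    return result, time.monotonic() - start

# Example (single process): times a dummy "checkpoint" write.
if __name__ == "__main__":
    _, elapsed = timed_collective(lambda: time.sleep(0.1))
    print(f"save took {elapsed:.3f}s")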
def check_adlr_autoresume_termination(iteration, model, optimizer,
                                      lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)
def run_checkpoint_test(yaml_list=None, param_dict=None):
    from megatron.checkpointing import load_checkpoint
    from megatron.checkpointing import save_checkpoint

    model, optimizer, lr_scheduler, args_loaded = model_setup(
        yaml_list, param_dict, clear_data=True)

    # save model checkpoint
    save_checkpoint(
        neox_args=args_loaded,
        iteration=42,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )

    # reload model from checkpoint
    (
        reloaded_model,
        reloaded_optimizer,
        reloaded_lr_scheduler,
        args_reloaded,
    ) = model_setup(yaml_list, param_dict, clear_data=False)
    iteration = load_checkpoint(
        neox_args=args_reloaded,
        model=reloaded_model,
        optimizer=reloaded_optimizer,
        lr_scheduler=reloaded_lr_scheduler,
    )

    # ensure the same checkpoint iteration is loaded
    assert iteration == 42, \
        "run_checkpoint_test() iteration loaded from checkpoint is incorrect"

    # check all weight groups are the same
    for idx, ((n1, p1), (n2, p2)) in enumerate(
            zip(
                list(model.module.named_parameters()),
                list(reloaded_model.module.named_parameters()),
            )):
        assert n1 == n2
        params_equal = (p1 == p2).all().item()
        assert params_equal, "run_checkpoint_test() params not equal: " + str(n1)
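# A self-contained sketch of the parameter-comparison pattern used in
# run_checkpoint_test above, shown with plain torch modules instead of the
# NeoX model/checkpoint machinery (model_setup, save_checkpoint, and
# load_checkpoint are assumed by the test and are not reproduced here).
import torch

def assert_same_parameters(model_a, model_b):
    """Check that two modules have identical parameter names and values."""
    for (n1, p1), (n2, p2) in zip(model_a.named_parameters(),
                                  model_b.named_parameters()):
        assert n1 == n2, f"parameter name mismatch: {n1} vs {n2}"
        assert torch.equal(p1, p2), f"parameter values differ for {n1}"

if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.nn.Linear(4, 4)
    b = torch.nn.Linear(4, 4)
    b.load_state_dict(a.state_dict())  # simulate a save/reload round trip
    assert_same_parameters(a, b)
    print("parameters match")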
def _train(model, optimizer, lr_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, _ = train_step(forward_step, batch, model,
                                        optimizer, lr_scheduler)
            iteration += 1

            # Logging.
            report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                              optimizer.param_groups[0]['lr'],
                                              iteration, optimizer.loss_scale,
                                              report_memory_flag)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, lr_scheduler)

            # Checkpointing
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, False)

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults={}):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It is
            used to set already-parsed arguments.
    """
    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func, model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults={}):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It is
            used to set already-parsed arguments.
    """
    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func, model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
def main():
    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we have plenty of processes.
    os.environ["WORLD_SIZE"] = f'{2**31}'

    # Args
    set_global_variables(extra_args_provider=get_mp_merge_args,
                         args_defaults={
                             'use_cpu_initialization': True,
                             'micro_batch_size': 1,
                             'no_load_optim': True,
                             'no_load_rng': True,
                             'no_save_optim': True,
                             'no_save_rng': True,
                             'save_interval': 1
                         })
    args = get_args()

    if args.pipeline_model_parallel_size > 1:
        print("Checkpoints with pipeline model parallelism are not currently supported.")
        exit()

    model_type = args.model_type
    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
    args.tensor_model_parallel_size = 1
    tokenizer = rebuild_tokenizer(args)

    print('\n merging model parallel partitions ...')
    print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
    print(' > checkpoint path: {}'.format(args.load))
    print(' > model parameters:')
    print('    number of tokens ................ {} '.format(tokenizer.vocab_size))
    print('    number of layers ................ {}'.format(args.num_layers))
    print('    hidden size ..................... {}'.format(args.hidden_size))
    print('    number of attention heads ....... {}'.format(args.num_attention_heads))
    print('    maximum position embeddings ..... {}'.format(args.max_position_embeddings))

    # Full model.
    print('> building the full model ...')
    mpu.initialize.set_tensor_model_parallel_world_size(1)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_world_size(1)
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    merged_model = get_model(model_type)

    # Build and load partitions.
    partitions = []
    iteration = 0
    args.tensor_model_parallel_size = orig_tensor_model_parallel_size
    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(
        args.tensor_model_parallel_size)
    for rank in range(args.tensor_model_parallel_size):
        # Reset these since load_checkpoint asserts they are 0, but we are loading
        # multiple checkpoints in the same process and they get set each time
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0

        mpu.initialize.set_tensor_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        model_ = get_model(model_type)
        print(f'> loading {checkpoint_name} ...')
        load_checkpoint(model_, None, None)
        print(f'> checkpoint version {get_checkpoint_version()}')
        partitions.append(model_)

    # Parameter generators so we can loop through them simultaneously.
    merged_params_gen = merged_model.named_parameters()
    partitions_params_gen = [
        partition.named_parameters() for partition in partitions
    ]
    while True:
        try:
            # Get the params and check names.
            name, merged_param = next(merged_params_gen)
            print(' > working on {} ...'.format(name))
            print('     merged type: {}, size: {}'.format(
                merged_param.dtype, list(merged_param.size())))
            partitions_param = []
            for rank, partition_params_gen in enumerate(partitions_params_gen):
                partition_name, partition_param = next(partition_params_gen)
                assert partition_name == name
                partitions_param.append(partition_param)
                print('     partition {} type: {}, size: {}'.format(
                    rank, partition_param.dtype, list(partition_param.size())))

            # For the non-parallel parameters, simply copy the rank 0 values.
            if not hasattr(merged_param, 'tensor_model_parallel'):
                print('     non-parallel parameter, simple copy from rank 0')
                with torch.no_grad():
                    merged_param.data.copy_(partitions_param[0].data)
            # For parallel parameters, merge the values
            else:
                dim = merged_param.partition_dim
                stride = merged_param.partition_stride
                print(f'     parallel parameter merge with stride {stride} along '
                      f'dimension {dim}')
                merge_partitions(merged_param, partitions_param, dim, stride)

        except StopIteration:
            break

    partitions = []
    args.tensor_model_parallel_size = 1
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size

    assert args.num_layers % args.pipeline_model_parallel_size == 0, \
        'num_layers must be divisible by target pipeline model parallel size'
    layers_per_part = args.num_layers // args.pipeline_model_parallel_size

    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(
        args.tensor_model_parallel_size)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_world_size(
        args.pipeline_model_parallel_size)

    # regex to parse out layer number from param name
    layer_re = re.compile(r'layers\.([0-9]+)')

    if args.pipeline_model_parallel_size > 1:
        merged_params = {}
        for name, merged_param in merged_model.named_parameters():
            merged_params[name] = merged_param

        for rank in range(args.pipeline_model_parallel_size):
            mpu.initialize.set_pipeline_model_parallel_rank(rank)
            model = get_model(model_type)

            def update_layer_num(m):
                # TODO! This assumes no interleaved pipeline execution
                layer = int(m.group(1))
                layer += rank * layers_per_part
                return f'layers.{layer}'

            for dst_name, partition_param in model.named_parameters():
                if dst_name == "word_embeddings.weight":
                    # See comment in MegatronModule.initialize_word_embeddings()
                    src_name = "language_model.embedding.word_embeddings.weight"
                else:
                    # Translate destination layer number (0-N for each partition)
                    # to source layer number (single-model layer number)
                    src_name = re.sub(layer_re, update_layer_num, dst_name)
                print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
                partition_param.data.copy_(merged_params[src_name].data)

            partitions.append(model)
    else:
        partitions = [merged_model]

    for rank, model in enumerate(partitions):
        mpu.initialize.set_pipeline_model_parallel_rank(rank)
        print(f"> saving rank {rank}'s model")
        save_checkpoint(iteration, model, None, None)

    print('done :-)')
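# A minimal sketch of what merging tensor-parallel partitions along a
# partition dimension amounts to when stride == 1: the per-rank shards are
# simply concatenated along that dimension. The real merge_partitions used
# above also handles strided partitioning; this helper is an illustrative
# assumption, not the Megatron implementation.
import torch

def merge_partitions_simple(partitions, dim):
    """Concatenate per-rank parameter shards along the partition dimension."""
    return torch.cat(partitions, dim=dim)

if __name__ == "__main__":
    # Two ranks each hold half of a column-parallel weight of shape (8, 4).
    shards = [torch.randn(4, 4), torch.randn(4, 4)]
    merged = merge_partitions_simple(shards, dim=0)
    assert merged.shape == (8, 4)
    print("merged shape:", tuple(merged.shape))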
def pretrain(neox_args):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule
        3) call train_val_test_data_provider to get train/val/test datasets.
        4) train the model.

    Arguments:
        neox_args: an instance of NeoXArgs containing the configuration for pretrain
    """
    # setup logging and timers
    init_wandb(neox_args=neox_args)
    timers = Timers(use_wandb=neox_args.use_wandb,
                    tensorboard_writer=neox_args.tensorboard_writer)

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(neox_args=neox_args)

    # Model, optimizer, and learning rate.
    timers("model and optimizer").start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        neox_args=neox_args, use_cache=False)
    timers("model and optimizer").stop()

    # Data stuff.
    timers("train/valid/test data iterators").start()
    (
        train_data_iterator,
        valid_data_iterator,
        test_data_iterator,
    ) = build_train_valid_test_data_iterators(neox_args=neox_args)
    timers("train/valid/test data iterators").stop()

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log(["model and optimizer", "train/valid/test data iterators"])
    print_rank_0("training ...")

    iteration = 0
    if neox_args.do_train and neox_args.train_iters > 0:
        iteration = train(
            neox_args=neox_args,
            timers=timers,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_data_iterator=train_data_iterator,
            valid_data_iterator=valid_data_iterator,
        )

    if neox_args.do_valid:
        prefix = "the end of training for val data"
        evaluate_and_print_results(
            neox_args=neox_args,
            prefix=prefix,
            forward_step_func=forward_step,
            data_iterator=valid_data_iterator,
            model=model,
            iteration=iteration,
            verbose=False,
            timers=timers,
        )

    if neox_args.save and iteration != 0:
        save_checkpoint(
            neox_args=neox_args,
            iteration=iteration,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )

    if neox_args.do_test:
        # Run on test data.
        prefix = "the end of training for test data"
        evaluate_and_print_results(
            neox_args=neox_args,
            prefix=prefix,
            forward_step_func=forward_step,
            data_iterator=test_data_iterator,
            model=model,
            iteration=0,  # iteration 0 in order to always use full test data
            verbose=True,
            timers=timers,
        )
def train(
    neox_args,
    timers,
    model,
    optimizer,
    lr_scheduler,
    train_data_iterator,
    valid_data_iterator,
):
    """Train the model function."""

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = neox_args.iteration

    timers("interval time").start()
    report_memory_flag = True

    # get noise scale logger (if neox_args.log_gradient_noise_scale is True)
    noise_scale_logger = get_noise_scale_logger(neox_args)

    # to monitor if we've skipped many iterations in a row and trigger an early exit
    overflow_monitor = OverflowMonitor(optimizer)
    while iteration < neox_args.train_iters:
        loss_dict, skipped_iter = train_step(
            neox_args=neox_args,
            timers=timers,
            data_iterator=train_data_iterator,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )
        iteration += 1

        overflow_monitor.check(skipped_iter)  # check for repeated overflow
        if neox_args.log_gradient_noise_scale:  # log noise scale if applicable
            noise_scale_logger.update()

        # get learning rate (if present) - if doing soft prompt tuning + pipe parallel, you
        # may have no tunable parameters on a specific rank
        if optimizer.param_groups:
            lr = optimizer.param_groups[0].get("lr", 0)
        else:
            lr = 0

        # Logging.
        report_memory_flag = training_log(
            neox_args=neox_args,
            timers=timers,
            loss_dict=loss_dict,
            total_loss_dict=total_loss_dict,
            learning_rate=lr,
            iteration=iteration,
            loss_scale=optimizer.cur_scale if neox_args.precision == "fp16" else None,
            report_memory_flag=report_memory_flag,
            skipped_iter=skipped_iter,
            model=model,
            optimizer=optimizer,
            noise_scale_logger=noise_scale_logger,
        )

        # Checkpointing
        if (neox_args.save and neox_args.save_interval
                and iteration % neox_args.save_interval == 0):
            save_checkpoint(
                neox_args=neox_args,
                iteration=iteration,
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )

        # Evaluation
        if (neox_args.eval_interval and iteration % neox_args.eval_interval == 0
                and neox_args.do_valid):
            prefix = "iteration {}".format(iteration)
            evaluate_and_print_results(
                neox_args=neox_args,
                prefix=prefix,
                forward_step_func=forward_step,
                data_iterator=valid_data_iterator,
                model=model,
                iteration=iteration,
                verbose=False,
                timers=timers,
            )

        if neox_args.exit_interval and iteration % neox_args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            rank = torch.distributed.get_rank()
            print_rank_0(
                "rank: {} | time: {} | exiting the program at iteration {}".format(
                    rank, time_str, iteration))
            sys.exit()

    return iteration
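# A hypothetical sketch of the kind of bookkeeping an overflow monitor such as
# OverflowMonitor (used above) performs: count consecutive skipped (overflowed)
# fp16 iterations and abort once a threshold is crossed. The class name,
# threshold, and error message here are illustrative assumptions, not the
# actual gpt-neox implementation.
class ConsecutiveSkipMonitor:
    def __init__(self, max_consecutive_skips=50):
        self.max_consecutive_skips = max_consecutive_skips
        self.consecutive_skips = 0

    def check(self, skipped_iter):
        """Call once per training step with 1 if the step was skipped, else 0."""
        if skipped_iter:
            self.consecutive_skips += 1
        else:
            self.consecutive_skips = 0
        if self.consecutive_skips >= self.max_consecutive_skips:
            raise RuntimeError(
                f"{self.consecutive_skips} consecutive loss-scale overflows; "
                "stopping training early.")

# Example: three skipped steps in a row trip a monitor with a threshold of 3.
monitor = ConsecutiveSkipMonitor(max_consecutive_skips=3)
for skipped in (1, 1, 0, 1, 1, 1):
    try:
        monitor.check(skipped)
    except RuntimeError as err:
        print(err)
        break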
def _train(model, optimizer, lr_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches() == 1, \
        "finetuning with gradient accumulation doesn't currently work"

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval-time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(
                losses_dict, losses_dict_sum,
                optimizer.param_groups[0]['lr'],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag, skipped_iter,
                grad_norm, params_norm, num_zeros_in_grad)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, lr_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)
                saved_checkpoint = True

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, False)

            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)