def train_batch(self, data_iterator, epoch_idx, batch_idx):
    """Run a single forward/backward/optimizer step and return the reduced loss."""
    if self.neox_args.is_pipe_parallel:
        reduced_loss = megatron_train.train_step_pipe(
            neox_args=self.neox_args,
            timers=self.timers,
            model=self.model,
            data_iterator=data_iterator,
        )
    else:
        losses = []
        for _ in range(self.neox_args.gradient_accumulation_steps):
            self.timers("forward").start()
            loss = megatron_train.forward_step(
                neox_args=self.neox_args,
                timers=self.timers,
                data_iterator=data_iterator,
                model=self.model,
            )
            self.timers("forward").stop()
            losses.append(loss)

            # Calculate gradients, reduce across processes, and clip.
            self.timers("backward").start()
            megatron_train.backward_step(
                neox_args=self.neox_args,
                timers=self.timers,
                optimizer=self.optimizer,
                model=self.model,
                loss=loss,
            )
            self.timers("backward").stop()

            # Update parameters.
            self.timers("optimizer").start()
            if self.neox_args.deepspeed:
                self.model.step()
            else:
                raise ValueError("Must be using deepspeed to run neox")
            self.timers("optimizer").stop()

        reduced_loss = {
            "lm_loss": megatron_utils.reduce_losses(losses).mean()
        }

    if self.neox_args.precision == "fp16" and self.model.optimizer.overflow:
        skipped_iter = 1
    else:
        skipped_iter = 0

    self.neox_args.iteration += 1

    self.overflow_monitor.check(skipped_iter)  # check for repeated overflow
    if self.neox_args.log_gradient_noise_scale:  # log noise scale if applicable
        self.noise_scale_logger.update()

    # get learning rate (if present) - if doing soft prompt tuning + pipe parallel, you
    # may have no tunable parameters on a specific rank
    if self.optimizer.param_groups:
        lr = self.optimizer.param_groups[0].get("lr", 0)
    else:
        lr = 0

    # Logging.
    self.report_memory_flag, additional_metrics = megatron_train.training_log(
        neox_args=self.neox_args,
        timers=self.timers,
        loss_dict=reduced_loss,
        total_loss_dict=self.total_train_loss_dict,
        learning_rate=lr,
        iteration=self.neox_args.iteration,
        loss_scale=self.optimizer.cur_scale
        if self.neox_args.precision == "fp16"
        else None,
        report_memory_flag=self.report_memory_flag,
        skipped_iter=skipped_iter,
        model=self.model,
        optimizer=self.optimizer,
        noise_scale_logger=self.noise_scale_logger,
        return_metrics=True,
    )
    if (additional_metrics is not None
            and additional_metrics["num_nans"] == 0
            and additional_metrics["num_skipped"] == 0):
        self.tflops = additional_metrics["flops_per_sec_per_gpu"] / 10**12

    if (self.neox_args.exit_interval
            and self.neox_args.iteration % self.neox_args.exit_interval == 0):
        torch.distributed.barrier()
        time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        rank = torch.distributed.get_rank()
        megatron_utils.print_rank_0(
            "time: {} | exiting the program at iteration {}".format(
                time_str, self.neox_args.iteration))
        self.context.set_stop_requested(True)

    return reduced_loss
def _train(model, optimizer, lr_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, _ = train_step(forward_step, batch, model,
                                        optimizer, lr_scheduler)
            iteration += 1

            # Logging.
            report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                              optimizer.param_groups[0]['lr'],
                                              iteration, optimizer.loss_scale,
                                              report_memory_flag)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, lr_scheduler)

            # Checkpointing
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, False)

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
def _train(model, optimizer, lr_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches() == 1, \
        "finetuning with gradient accumulation doesn't currently work"

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval-time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer,
                             lr_scheduler)
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(
                losses_dict, losses_dict_sum,
                optimizer.param_groups[0]['lr'],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag, skipped_iter,
                grad_norm, params_norm, num_zeros_in_grad)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, lr_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)
                saved_checkpoint = True

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, False)

            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
                torch.distributed.barrier()
                print_rank_0(
                    'exiting program at iteration {}'.format(iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)