def metrics_func(model, epoch, output_predictions=False):
    print_rank_last('calculating metrics ...')
    correct = 0
    total = 0
    if output_predictions:
        assert mpu.get_data_parallel_world_size() == 1
        named_predictions = []
        names = 'predictions'
    for name, dataloader in dataloaders:
        output = calculate_correct_answers(name, model, dataloader,
                                           epoch, output_predictions)
        if not output_predictions:
            correct_ans, total_count = output
        else:
            correct_ans, total_count, predictions = output
            named_predictions.append((name, predictions))
            names += '_' + name
        correct += correct_ans
        total += total_count
    if is_last_rank():
        percent = float(correct) * 100.0 / float(total)
        print(' >> |epoch: {}| overall: correct / total = {} / {} = '
              '{:.4f} %'.format(epoch, correct, total, percent))

    if output_predictions and is_last_rank():
        assert args.load is not None
        filename = os.path.join(args.load, names + '.pt')
        torch.save(named_predictions, filename)
def load_state_dict(self, state_dict, strict=True):
    """Customized load."""

    self.language_model.load_state_dict(
        state_dict[self._language_model_key], strict=strict)

    if self.post_process:
        if self._multichoice_head_key in state_dict:
            self.multichoice_head.load_state_dict(
                state_dict[self._multichoice_head_key], strict=strict)
        else:
            print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                            'initializing to random'.format(
                                self._multichoice_head_key))
def load_state_dict(self, state_dict, strict=True):
    """Customized load."""

    self.language_model.load_state_dict(
        state_dict[self._language_model_key], strict=strict)

    if mpu.is_pipeline_last_stage():
        if self._classification_head_key in state_dict:
            self.classification_head.load_state_dict(
                state_dict[self._classification_head_key], strict=strict)
        else:
            print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                            'initializing to random'.format(
                                self._classification_head_key))
def load_state_dict(checkpoint_name):
    try:
        state_dict = torch.load(checkpoint_name, map_location="cpu")
    except ModuleNotFoundError:
        from megatron.fp16_deprecated import loss_scaler
        # For backward compatibility.
        print_rank_last(" > deserializing using the old code structure ...")
        sys.modules["fp16.loss_scaler"] = sys.modules[
            "megatron.fp16_deprecated.loss_scaler"]
        sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
            "megatron.fp16_deprecated.loss_scaler"]
        state_dict = torch.load(checkpoint_name, map_location="cpu")
        sys.modules.pop("fp16.loss_scaler", None)
        sys.modules.pop("megatron.fp16.loss_scaler", None)
    return state_dict
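# A self-contained sketch (all names hypothetical, unrelated to Megatron) of why
# the sys.modules aliasing above works: pickle records each class by the module
# path it had when the checkpoint was written, so torch.load raises
# ModuleNotFoundError once that module has been renamed. Registering the
# relocated module under the old name lets the unpickler resolve the class again.
import pickle
import sys
import types

relocated = types.ModuleType("relocated_scaler")    # the "new" module


class LossScaler:
    """Stand-in for a class whose module was renamed."""


relocated.LossScaler = LossScaler
LossScaler.__module__ = "legacy_scaler"              # path baked into old pickles

sys.modules["legacy_scaler"] = relocated             # temporary alias
payload = pickle.dumps(LossScaler())                 # simulate an old checkpoint
obj = pickle.loads(payload)                          # resolves via the alias
sys.modules.pop("legacy_scaler", None)               # clean up, as the code above does
assert isinstance(obj, LossScaler)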
def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} value-validation'.format(key),
                              total_loss_dict[key].item(),
                              iteration)
            writer.add_scalar('{} ppl-validation'.format(key), ppl, iteration)
            writer.add_scalar('{} value-validation vs samples'.format(key),
                              total_loss_dict[key].item(),
                              args.consumed_train_samples)
            writer.add_scalar('{} ppl-validation vs samples'.format(key),
                              ppl, args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
    """Load a model checkpoint and return the iteration."""
    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron.checkpointing import get_checkpoint_tracker_filename
    from megatron.checkpointing import set_checkpoint_version
    from megatron.checkpointing import check_checkpoint_args
    from megatron.checkpointing import update_num_microbatches

    if mpu.get_data_parallel_rank() == 0:
        # At data parallel rank 0 we still follow Megatron's native
        # load_checkpoint.
        from megatron.checkpointing import load_checkpoint as load_checkpoint_native
        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)

    args = get_args()
    load_dir = getattr(args, load_arg)

    if isinstance(model, DistributedDataParallel):
        model = model.module

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_last("WARNING: could not find the metadata file {} ".format(
            tracker_filename))
        print_rank_last("    will not load any checkpoints and will start from "
                        "random")
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == "release"
            if not release:
                print_rank_last("ERROR: Invalid metadata file {}. Exiting".format(
                    tracker_filename))
                sys.exit()

    assert iteration > 0 or release, "error parsing metadata file {}".format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration,
                                                     release, 0)
    checkpoint_name_local = get_fmoe_checkpoint_name(
        load_dir, iteration, release, mpu.get_data_parallel_rank())
    print_rank_last(" loading checkpoint at rank 0 from {} and rank {} from {} "
                    "at iteration {}, will merge them later".format(
                        checkpoint_name_rank0,
                        mpu.get_data_parallel_rank(),
                        checkpoint_name_local,
                        iteration))

    # Load the checkpoint.
    def load_state_dict(checkpoint_name):
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler
            # For backward compatibility.
            print_rank_last(" > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        except BaseException:
            print_rank_last("could not load the checkpoint")
            sys.exit()
        return state_dict

    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
    state_dict_local = load_state_dict(checkpoint_name_local)

    state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)

    # Set checkpoint version.
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
        except KeyError:
            try:
                # Backward compatible with older checkpoints.
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_last("A metadata file exists but unable to load "
                                "iteration from checkpoint {}, exiting".format(
                                    checkpoint_name_local))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              "consumed_train_samples", 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              "consumed_valid_samples", 0)
    else:
        print_rank_last("could not find arguments in the checkpoint ...")

    # Model.
    model.load_state_dict(state_dict["model"])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        except KeyError:
            print_rank_last("Unable to load optimizer from checkpoint {}. "
                            "Specify --no-load-optim or --finetune to prevent "
                            "attempting to load the optimizer state, "
                            "exiting ...".format(checkpoint_name_local))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_last("Unable to load rng states from checkpoint {}. "
                            "Specify --no-load-rng or --finetune to prevent "
                            "attempting to load the rng state, "
                            "exiting ...".format(checkpoint_name_local))
            sys.exit()

    torch.distributed.barrier()
    print_rank_last(" successfully loaded checkpoint (with expert parameters "
                    "updated) from {} at iteration {}".format(
                        args.load, iteration))

    return iteration
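# A simplified sketch of the idea behind the merge_state_dict call in
# load_checkpoint above (the real FastMoE implementation also merges optimizer
# state and fp16 master weights): overlay the expert-only leaves saved by this
# data parallel rank onto the dense checkpoint written by data parallel rank 0.
# merge_model_state is a hypothetical name used only for illustration.
from collections import OrderedDict


def merge_model_state(rank0_dict, local_dict):
    """Recursively insert expert tensors from local_dict into rank0_dict."""
    for key, value in local_dict.items():
        if isinstance(value, (OrderedDict, dict)):
            merge_model_state(rank0_dict.setdefault(key, value.__class__()),
                              value)
        else:
            # Leaf tensor: the local (expert) copy wins over rank 0's copy.
            rank0_dict[key] = value
    return rank0_dict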
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallelism."""
    # TODO: update patch
    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last

    expert_dp_comm = "none"

    if mpu.get_data_parallel_rank() == 0:
        # At data parallel rank 0 we still follow Megatron's native
        # save_checkpoint.
        from megatron.checkpointing import save_checkpoint as save_checkpoint_native
        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
        return

    args = get_args()

    # Data parallel rank 0 writes the full checkpoint above; the remaining
    # ranks only write their expert parameters below.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    print_rank_last("saving checkpoint at iteration {:7d} to {}".format(
        iteration, args.save))

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict["model"] = model.state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0))

    def extract_expert_param(state_dict, expert_dp_comm="none"):
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            # Megatron uses both dict and OrderedDict in its state_dict.
            if isinstance(v, (OrderedDict, dict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                if len(v_new) > 0:
                    state_dict_new[k] = v_new
            elif hasattr(v, "dp_comm") and v.dp_comm == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    state_dict["model"] = extract_expert_param(state_dict["model"],
                                               expert_dp_comm)

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
            param_global_idx = 0
            for param_group in optimizer.optimizer.param_groups:
                for param in param_group["params"]:
                    if not (hasattr(param, "dp_comm")
                            and param.dp_comm == expert_dp_comm):
                        # This parameter is not an expert parameter, so there
                        # is no need to save its state on the current rank;
                        # it has already been saved by data parallel rank 0.
                        if args.fp16:
                            # The fp16 optimizer may have empty state due to
                            # overflow.
                            state_dict["optimizer"]["optimizer"]["state"].pop(
                                param_global_idx, None)
                        else:
                            state_dict["optimizer"]["state"].pop(
                                param_global_idx)
                    param_global_idx += 1
            if args.fp16:
                state_dict["optimizer"]["optimizer"].pop("param_groups")
                # fp32_from_fp16_params in state_dict is not a copy but a
                # reference to optimizer.fp32_from_fp16_params; changing it in
                # state_dict would change optimizer.fp32_from_fp16_params as
                # well. Thus we create an empty fp32_from_fp16_params in
                # state_dict and only insert expert parameters.
                fp32_from_fp16_params = state_dict["optimizer"][
                    "fp32_from_fp16_params"]
                state_dict["optimizer"]["fp32_from_fp16_params"] = []
                for param_group in fp32_from_fp16_params:
                    param_group_copy = []
                    for param in param_group:
                        param_copy = (param if hasattr(param, "dp_comm")
                                      and param.dp_comm == expert_dp_comm
                                      else None)
                        param_group_copy.append(param_copy)
                    state_dict["optimizer"]["fp32_from_fp16_params"].append(
                        param_group_copy)
            else:
                state_dict["optimizer"].pop("param_groups")

    # Save.
    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename
    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary).
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" successfully saved checkpoint at iteration {:7d} to {}".format(
            iteration, args.save),
            flush=True)

    # And update the latest iteration.
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary).
    torch.distributed.barrier()
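# A hedged sketch (not the actual FastMoE code) of the kind of path a helper
# like get_fmoe_checkpoint_name presumably produces: the same directory layout
# as Megatron's get_checkpoint_name, with the data parallel rank folded into
# the subdirectory so every rank's expert shard gets its own file. The function
# and parameter names below are illustrative assumptions.
import os


def fmoe_checkpoint_name_sketch(checkpoints_path, iteration, release=False,
                                data_parallel_rank=0, tensor_mp_rank=0):
    directory = "release" if release else "iter_{:07d}".format(iteration)
    sub_dir = "mp_rank_{:02d}_dp_rank_{:04d}".format(tensor_mp_rank,
                                                     data_parallel_rank)
    return os.path.join(checkpoints_path, directory, sub_dir,
                        "model_optim_rng.pt")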
def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()
    start_time = time.time()
    model.eval()
    saved_batch_size = args.micro_batch_size
    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # Run the model forward.
            tokens, types, labels_, attention_mask = process_batch(batch)

            # For evaluation only mode we use drop_last = False to get all the
            # samples, which means we might not have a full batch, so we
            # adjust batch_size here to actual batch size of data ...
            actual_batch_size = len(labels_)
            # ... applying sample_multiplier if necessary.
            ds = dataloader.dataset
            if hasattr(ds, 'sample_multiplier'):
                actual_batch_size *= ds.sample_multiplier
            args.micro_batch_size = actual_batch_size

            if not mpu.is_pipeline_first_stage():
                input_tensor, _ = communicate(
                    tensor_send_next=None,
                    tensor_send_prev=None,
                    recv_forward=True,
                    recv_backward=False)
            else:
                input_tensor = None

            # Forward model.
            if mpu.is_pipeline_first_stage():
                assert input_tensor is None
                output_tensor = model(tokens, attention_mask,
                                      tokentype_ids=types)
            else:
                assert input_tensor is not None
                output_tensor = model(input_tensor, attention_mask)

            if mpu.is_pipeline_last_stage():
                logits = output_tensor

                # Add output predictions.
                if output_predictions:
                    softmaxes.extend(torch.nn.Softmax(dim=-1)(
                        logits.float()).data.cpu().numpy().tolist())
                    labels.extend(labels_.data.cpu().numpy().tolist())
                    ids.extend(batch['uid'].cpu().numpy().tolist())
                # Compute the correct answers.
                predicted = torch.argmax(logits, dim=-1)
                corrects = (predicted == labels_)
                # Add to the counters.
                total += labels_.size(0)
                correct += corrects.sum().item()
            else:
                communicate(
                    tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)

    model.train()
    args.micro_batch_size = saved_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                        '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                            epoch, name, correct_ans, total_count,
                            percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-send-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
    add_to_logging('optimizer-unscale-and-check-inf')
    add_to_logging('optimizer-clip-main-grad')
    add_to_logging('optimizer-copy-main-to-model-params')
    add_to_logging('optimizer')
    add_to_logging('batch-generator')

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()
    total_iterations = total_loss_dict[advanced_iters_key] + \
        total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and is_last_rank():
        writer.add_scalar('learning-rate', learning_rate, iteration)
        writer.add_scalar('learning-rate vs samples', learning_rate,
                          args.consumed_train_samples)
        writer.add_scalar('batch-size', batch_size, iteration)
        writer.add_scalar('batch-size vs samples', batch_size,
                          args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        writer.add_scalar('loss-scale', loss_scale, iteration)
        writer.add_scalar('loss-scale vs samples', loss_scale,
                          args.consumed_train_samples)
        timers.write(timers_to_log, writer, iteration,
                     normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration-time',
                              elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                    float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag
def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, each "sample" from
        # the dataset actually contains multiple samples that will collapse
        # into the batch dimension (for example in the RACE dataset, which has
        # several options per question), so we need to account for that when
        # setting the micro batch size.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = \
        args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = \
        args.orig_global_batch_size // micro_batch_size_times_data_parallel

    def loss_func(output_predictions, labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Add output predictions (not supported in this code path).
        if output_predictions:
            assert False
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # Defined inside to capture output_predictions.
    def correct_answers_forward_step(batch, model):
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        args = get_args()
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        return output_tensor, partial(loss_func, output_predictions, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # For evaluation only mode we use drop_last = False to get all the
            # samples, which means we might not have a full batch, so we
            # adjust batch_size here to actual batch size of data ...
            actual_batch_size = len(batch['label'])
            # ... applying sample_multiplier if necessary.
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = \
                actual_batch_size * sample_multiplier * num_micro_batches

            loss_dicts = forward_backward_func(correct_answers_forward_step,
                                               batch, model,
                                               optimizer=None,
                                               timers=None,
                                               forward_only=True)

            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                        '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                            epoch, name, correct_ans, total_count,
                            percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
def metrics_func(model, epoch):
    print_rank_0("calculating metrics ...")
    correct, total = calculate_correct_answers(model, dataloader, epoch)
    percent = float(correct) * 100.0 / float(total)
    print_rank_last(" >> |epoch: {}| overall: correct / total = {} / {} = "
                    "{:.4f} %".format(epoch, correct, total, percent))
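# A hedged sketch of how a metrics_func like the one above is typically wired
# up in Megatron-style finetuning: a provider builds the validation dataloader
# once and returns the closure, which the training loop then invokes as an
# end-of-epoch callback. The names metrics_func_provider and
# build_validation_dataloader are illustrative assumptions, not the actual API.
def metrics_func_provider():
    dataloader = build_validation_dataloader()    # hypothetical helper

    def metrics_func(model, epoch):
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, dataloader, epoch)
        percent = 100.0 * correct / max(1, total)
        print_rank_last(" >> |epoch: {}| overall: correct / total = {} / {} = "
                        "{:.4f} %".format(epoch, correct, total, percent))

    return metrics_func


# The training loop would then call, roughly:
#     end_of_epoch_callback = metrics_func_provider()
#     end_of_epoch_callback(model, epoch)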