def build_pretraining_data_loader(dataset, consumed_samples):
    """Build dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler
    if args.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size(),
            data_sharding=args.data_sharding)
    else:
        raise Exception('{} dataloader type is not supported.'.format(
            args.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
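# Hedged sketch of why passing `consumed_samples` gives exact resumption
# (the helper below is illustrative, not part of the codebase; it mirrors
# what MegatronPretrainingSampler is assumed to do): one global step covers
# micro_batch_size * data_parallel_size consecutive sample indices, and
# each data-parallel rank takes its own contiguous micro-batch-sized slice
# starting at the saved counter.
def micro_batch_indices(consumed, micro_batch_size, dp_rank):
    start = consumed + dp_rank * micro_batch_size
    return list(range(start, start + micro_batch_size))

# With micro_batch_size=4 and two data-parallel ranks, the step after 64
# consumed samples hands rank 0 samples 64..67 and rank 1 samples 68..71.
assert micro_batch_indices(64, 4, 0) == [64, 65, 66, 67]
assert micro_batch_indices(64, 4, 1) == [68, 69, 70, 71]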
def allreduce_gradients(self):
    """Reduce gradients across data parallel ranks."""
    # If we have buffers, simply reduce the data in the buffer.
    if self._grad_buffers is not None:
        for _, buffer_ in self._grad_buffers.items():
            buffer_.data /= mpu.get_data_parallel_world_size()
            torch.distributed.all_reduce(
                buffer_.data, group=mpu.get_data_parallel_group())
    else:
        # Otherwise, bucketize and all-reduce
        buckets = {}
        # Pack the buckets.
        for param in self.module.parameters():
            if param.requires_grad and param.grad is not None:
                tp = param.data.type()
                if tp not in buckets:
                    buckets[tp] = []
                buckets[tp].append(param)
                param.main_grad = param.grad

        # For each bucket, all-reduce and copy all-reduced grads.
        for tp in buckets:
            bucket = buckets[tp]
            grads = [param.grad.data for param in bucket]
            coalesced = _flatten_dense_tensors(grads)
            coalesced /= mpu.get_data_parallel_world_size()
            torch.distributed.all_reduce(
                coalesced, group=mpu.get_data_parallel_group())
            for buf, synced in zip(
                    grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
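# Minimal single-process sketch of the bucketed path above (assumption:
# no process group is initialized here, so dividing by a stand-in world
# size of 2 replaces the all_reduce call). It demonstrates the
# flatten -> reduce -> unflatten -> copy_ round trip that writes the
# coalesced result back into the per-parameter gradient tensors.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2, 2), 2.0)]
coalesced = _flatten_dense_tensors(grads)
coalesced /= 2  # stands in for divide-by-world-size + all_reduce
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)
assert grads[0][0].item() == 0.5 and grads[1][0, 0].item() == 1.0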
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    assert False, 'DistributedBatchSampler deprecated, change the implementation'
    from megatron.data.samplers import DistributedBatchSampler
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
def metrics_func(model, epoch, output_predictions=False):
    print_rank_0('calculating metrics ...')
    correct = 0
    total = 0
    if output_predictions:
        assert mpu.get_data_parallel_world_size() == 1
        named_predictions = []
        names = 'predictions'
    for name, dataloader in dataloaders:
        output = calculate_correct_answers(name, model, dataloader,
                                           epoch, output_predictions)
        if not output_predictions:
            correct_ans, total_count = output
        else:
            correct_ans, total_count, predictions = output
            named_predictions.append((name, predictions))
            names += '_' + name
        correct += correct_ans
        total += total_count
    percent = float(correct) * 100.0 / float(total)
    print_rank_0(' >> |epoch: {}| overall: correct / total = {} / {} = '
                 '{:.4f} %'.format(epoch, correct, total, percent))

    if output_predictions and torch.distributed.get_rank() == 0:
        assert args.load is not None
        filename = os.path.join(args.load, names + '.pt')
        torch.save(named_predictions, filename)
def accuracy_func_provider():
    """Provide function that calculates accuracies."""
    args = get_args()
    data_path = args.data_path
    crop_size = args.img_dim

    # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    # Build dataloaders.
    val_data_path = os.path.join(data_path[0], "val")
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
    transform_val = transforms.Compose([
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])
    dataset = datasets.ImageFolder(root=val_data_path,
                                   transform=transform_val)

    dataloader = build_data_loader(
        dataset,
        args.micro_batch_size,
        num_workers=args.num_workers,
        drop_last=(mpu.get_data_parallel_world_size() > 1),
    )

    def metrics_func(model, epoch):
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, dataloader, epoch)
        percent = float(correct) * 100.0 / float(total)
        print_rank_last(" >> |epoch: {}| overall: correct / total = {} / {} = "
                        "{:.4f} %".format(epoch, correct, total, percent))

    return metrics_func
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if torch.distributed.is_initialized():
            mpu.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = mpu.get_model_parallel_group()
            app_state.data_parallel_group = mpu.get_data_parallel_group()
            app_state.model_parallel_rank = mpu.get_tensor_model_parallel_rank()
            app_state.data_parallel_rank = mpu.get_data_parallel_rank()
            app_state.data_parallel_size = mpu.get_data_parallel_world_size()
            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')

            # TODO: get random seed from PTL
            # os.environ values are strings, so cast before seeding.
            seed = int(os.environ.get("PL_GLOBAL_SEED", 1234))
            # random seed must be set for megatron model parallel init
            _set_random_seed(seed)
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_mask, \
        context_tokens, context_mask, context_indices = \
        get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and Context Types
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    query_logits, context_logits = model(query_tokens, query_mask,
                                         query_types, context_tokens,
                                         context_mask, context_types)

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1
    assert mpu.get_tensor_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    global_batch_size = torch.distributed.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
                                    torch.transpose(all_context_logits, 0, 1))
    # scaling the retriever scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                             k=softmax_scores.shape[1],
                                             sorted=True)

    def topk_accuracy(k):
        return torch.cuda.FloatTensor(
            [sum([int(i in sorted_indices[i, :k])
                  for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k))
                 for k in args.retriever_report_topk_accuracies]

    labels = torch.arange(global_batch_size).long().cuda()
    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
    reduced_losses = average_losses_across_data_parallel_group(
        [loss, *topk_accs])

    # Scale the retrieval loss
    loss = loss * mpu.get_data_parallel_world_size()

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v * 100
                     for k, v in zip(args.retriever_report_topk_accuracies,
                                     reduced_losses[1:])}
    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
    return loss, stats_dict
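# CPU sketch of the top-k retrieval accuracy computed above (the score
# matrix is made up for illustration): query i's correct context is
# column i, so top-1 accuracy counts rows whose argmax lands on the
# diagonal. Here row 2 retrieves the wrong context, giving 2/3.
import torch

scores = torch.tensor([[0.9, 0.1, 0.0],
                       [0.2, 0.7, 0.1],
                       [0.6, 0.3, 0.1]])
_, idx = torch.topk(scores, k=1)
top1 = sum(int(i in idx[i]) for i in range(3)) / 3
assert abs(top1 - 2 / 3) < 1e-9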
def calculate_correct_answers(name, model, dataloader, epoch,
                              output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""

    start_time = time.time()
    model.eval()
    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # Run the model forward.
            tokens, types, labels_, attention_mask = process_batch(batch)
            logits = model(tokens, attention_mask, types)
            # Add output predictions.
            if output_predictions:
                softmaxes.extend(torch.nn.Softmax(dim=-1)(
                    logits.float()).data.cpu().numpy().tolist())
                labels.extend(labels_.data.cpu().numpy().tolist())
                ids.extend(batch['uid'].cpu().numpy().tolist())
            # Compute the correct answers.
            predicted = torch.argmax(logits, dim=-1)
            corrects = (predicted == labels_)
            # Add to the counters.
            total += labels_.size(0)
            correct += corrects.sum().item()
    model.train()

    # Reduce.
    unreduced = torch.cuda.LongTensor([correct, total])
    torch.distributed.all_reduce(unreduced,
                                 group=mpu.get_data_parallel_group())

    # Print on screen.
    correct_ans = unreduced[0].item()
    total_count = unreduced[1].item()
    percent = float(correct_ans) * 100.0 / float(total_count)
    elapsed_time = time.time() - start_time
    print_rank_0(' > |epoch: {}| metrics for {}: correct / total '
                 '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                     epoch, name, correct_ans, total_count, percent,
                     elapsed_time))

    if output_predictions:
        return correct_ans, total_count, (softmaxes, labels, ids)
    return correct_ans, total_count
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            for _ in range(get_num_microbatches()):
                if not mpu.is_pipeline_first_stage():
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)
                else:
                    input_tensor = None

                # Forward evaluation.
                output_tensor = forward_step_func(data_iterator, model,
                                                  input_tensor)

                if mpu.is_pipeline_last_stage():
                    _, loss_dict = output_tensor
                    # Reduce across processes.
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + \
                            loss_dict[key]
                else:
                    communicate(
                        tensor_send_next=output_tensor,
                        tensor_send_prev=None,
                        recv_forward=False,
                        recv_backward=False)

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                * args.micro_batch_size \
                * get_num_microbatches()

    # Move model back to the train mode.
    model.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict
def get_rng_state():
    """Collect rng state across data parallel ranks."""
    args = get_args()
    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()
    }

    rng_state_list = None
    if torch.distributed.is_initialized() and \
            mpu.get_data_parallel_world_size() > 1 and \
            args.data_parallel_random_init:
        rng_state_list = \
            [None for i in range(mpu.get_data_parallel_world_size())]
        torch.distributed.all_gather_object(
            rng_state_list,
            rng_state,
            group=mpu.get_data_parallel_group())
    else:
        rng_state_list = [rng_state]

    return rng_state_list
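# Hedged counterpart sketch: restoring one element of the list returned
# by get_rng_state() at checkpoint load time. Assumptions: the import
# path `from megatron import mpu` matches this codebase's layout, a CUDA
# device is present, and mpu.get_cuda_rng_tracker().set_states() is the
# inverse of the get_states() call above.
import random
import numpy as np
import torch
from megatron import mpu

def restore_rng_state(rng_state):
    random.setstate(rng_state['random_rng_state'])
    np.random.set_state(rng_state['np_rng_state'])
    torch.set_rng_state(rng_state['torch_rng_state'])
    torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
    mpu.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states'])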
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(
                    iteration, args.eval_iters))

            forward_backward_func = get_forward_backward_func()
            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + \
                            loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                * args.micro_batch_size \
                * get_num_microbatches()

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict
def accuracy_func_provider(single_dataset_provider):
    """Provide function that calculates accuracies."""
    args = get_args()

    # Build dataloaders.
    datapaths = args.valid_data
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
        dataloader = build_data_loader(
            dataset, args.orig_micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0
        if output_predictions:
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'
        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            correct += correct_ans
            total += total_count
        if is_last_rank():
            percent = float(correct) * 100.0 / float(total)
            print(' >> |epoch: {}| overall: correct / total = {} / {} = '
                  '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and is_last_rank():
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func
def __init__(self):
    args = get_args()
    self.model = None
    self.dataloader = None
    self.block_data = None

    # need to know whether we're using a REALM checkpoint (args.load)
    # or ICT checkpoint
    assert not (args.load and args.ict_load)
    self.using_realm_chkpt = args.ict_load is None

    self.log_interval = args.indexer_log_interval
    self.batch_size = args.indexer_batch_size

    self.load_attributes()
    self.is_main_builder = mpu.get_data_parallel_rank() == 0
    self.num_total_builders = mpu.get_data_parallel_world_size()
    self.iteration = self.total_processed = 0
def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
    """Provide function that calculates accuracies."""
    args = get_args()

    print_rank_0("accuracy_func_provider is CALLED")

    # Build dataloaders
    datapath = args.valid_data
    dataset = single_dataset_provider(datapath)

    drop_last = False
    if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
        drop_last = True

    print_rank_0(datapath)
    print_rank_0(rank0sampler)

    dataloader = build_data_loader(dataset,
                                   args.eval_micro_batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=drop_last,
                                   task_collate_fn=task_collate_fn)
    dataloaders = (dataset.dataset_name, dataloader)

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_0('calculating metrics by accuracy func in ORQA...')

        if output_predictions:
            assert rank0sampler
            names = 'predictions'
        name, dataloader = dataloaders
        if args.task == "RET-FINETUNE-NQ":
            start_time = time.time()
            output = retrieval_loss(model, dataloader)
            stats_dict, total = output
            format_string = ""
            for k, v in stats_dict.items():
                format_string += "|{} = {:.2f}".format(k, v / total)
            print_rank_0("epoch:{}{}".format(epoch, format_string))
            print_rank_0("taken time to calculate metrics {:.3f}".format(
                time.time() - start_time))
        else:
            raise AssertionError("{} Task not supported".format(args.task))

    return metrics_func
def __init__(self):
    args = get_args()
    self.model = None
    self.dataloader = None
    self.evidence_embedder_obj = None
    self.biencoder_shared_query_context_model = \
        args.biencoder_shared_query_context_model

    # need to know whether we're using a REALM checkpoint (args.load)
    # or ICT checkpoint
    assert not (args.load and args.ict_load)

    self.log_interval = args.indexer_log_interval
    self.batch_size = args.indexer_batch_size

    self.load_attributes()
    self.is_main_builder = mpu.get_data_parallel_rank() == 0
    self.num_total_builders = mpu.get_data_parallel_world_size()
    self.iteration = self.total_processed = 0
def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
    self.cache_hook = base.CacheHook(None)
    self.model = model
    self.neox_args = neox_args
    self.tokenizer = neox_args.tokenizer
    self._device = torch.device(f"cuda:{neox_args.local_rank}")
    self._eot_token_id = neox_args.tokenizer.eod_id
    self._max_length = neox_args.max_position_embeddings // 2
    self._max_gen_toks = 128
    self._vocab_size = neox_args.padded_vocab_size

    # parallelism args:
    self.is_main = neox_args.rank == 0
    self.is_local_main = neox_args.local_rank == 0
    self.is_model_parallel = neox_args.model_parallel_size > 1
    self.is_pipe_parallel = self.model.is_pipe_parallel
    self.is_data_parallel = self.model.is_data_parallel
    self.is_last_stage = (
        True if not self.is_pipe_parallel else model.is_last_stage()
    )  # only the last stage of the pipeline model will receive the logits
    self.dp_world_size = mpu.get_data_parallel_world_size()
    self.dp_rank = mpu.get_data_parallel_rank()
    self.dp_group = mpu.get_data_parallel_group()
    self.is_mp_rank_0 = mpu.get_model_parallel_rank() == 0

    self._batch_size = batch_size or (
        neox_args.batch_size * self.dp_world_size
    )  # default batch size to bs per gpu * dp size

    # some utility functions:
    # we need to patch tokenizer methods, because lm_eval uses them internally:
    self.tokenizer.encode = self.tokenizer.tokenize
    self.tokenizer.decode = self.tokenizer.detokenize
    self._forward_step_fn = partial(
        forward_step_fn, neox_args=neox_args, timers=None, return_logits=True
    )
    self.generate = partial(
        generate_samples_from_prompt,
        neox_args=neox_args,
        model=model,
        maximum_tokens=self._max_gen_toks,
        temperature=0.0,
    )
def build_data_loader(dataset, batch_size, num_workers, drop_last):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              sampler=sampler,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              drop_last=drop_last,
                                              pin_memory=True)
    return data_loader
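# Minimal sketch of the padding behavior behind the
# drop_last=(world_size > 1) choice seen in the callers above:
# DistributedSampler pads its index list so every rank gets the same
# number of samples, which double-counts a few examples in evaluation
# unless incomplete batches are dropped. Passing num_replicas/rank
# explicitly means no process group is needed for this demo; the sizes
# are illustrative.
import torch
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))  # 10 samples split across 3 "ranks"
per_rank = [list(DistributedSampler(dataset, num_replicas=3, rank=r,
                                    shuffle=False))
            for r in range(3)]
total = sum(len(indices) for indices in per_rank)
assert total == 12  # 2 indices are repeated to even out the ranks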
def build_pretraining_data_loader(dataset, consumed_samples):
    """Build dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        micro_batch_size=args.micro_batch_size,
        data_parallel_rank=mpu.get_data_parallel_rank(),
        data_parallel_size=mpu.get_data_parallel_world_size())

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
def make_data_loader(dataset, neox_args):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None
    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = neox_args.batch_size * world_size
    num_workers = neox_args.num_workers

    # Use a simple sampler with distributed batch sampler.
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)
    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
def get_one_epoch_dataloader(dataset, batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    num_workers = args.num_workers

    # Use megatron's sampler with consumed samples set to 0 as
    # this is only for evaluation and don't intend to resume half way.
    # Also, set the drop last to false as don't intend to remove
    # the last batch
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        # use the (possibly overridden) local micro_batch_size computed
        # above rather than args.micro_batch_size, so the argument takes
        # effect.
        micro_batch_size=micro_batch_size,
        data_parallel_rank=mpu.get_data_parallel_rank(),
        data_parallel_size=mpu.get_data_parallel_world_size(),
        drop_last=False)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
def calculate_correct_answers(name, model, dataloader, epoch,
                              output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()

    start_time = time.time()
    model.eval()
    saved_batch_size = args.micro_batch_size
    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # Run the model forward.
            tokens, types, labels_, attention_mask = process_batch(batch)

            # For evaluation only mode we use drop_last = False to get all the
            # samples, which means we might not have a full batch, so we
            # adjust batch_size here to actual batch size of data
            actual_batch_size = len(labels_)
            # ... applying sample_multiplier if necessary
            ds = dataloader.dataset
            if hasattr(ds, 'sample_multiplier'):
                actual_batch_size *= ds.sample_multiplier
            args.micro_batch_size = actual_batch_size

            if not mpu.is_pipeline_first_stage():
                input_tensor, _ = communicate(
                    tensor_send_next=None,
                    tensor_send_prev=None,
                    recv_forward=True,
                    recv_backward=False)
            else:
                input_tensor = None

            # Forward model.
            if mpu.is_pipeline_first_stage():
                assert input_tensor is None
                output_tensor = model(tokens, attention_mask,
                                      tokentype_ids=types)
            else:
                assert input_tensor is not None
                output_tensor = model(input_tensor, attention_mask)

            if mpu.is_pipeline_last_stage():
                logits = output_tensor

                # Add output predictions.
                if output_predictions:
                    softmaxes.extend(torch.nn.Softmax(dim=-1)(
                        logits.float()).data.cpu().numpy().tolist())
                    labels.extend(labels_.data.cpu().numpy().tolist())
                    ids.extend(batch['uid'].cpu().numpy().tolist())
                # Compute the correct answers.
                predicted = torch.argmax(logits, dim=-1)
                corrects = (predicted == labels_)
                # Add to the counters.
                total += labels_.size(0)
                correct += corrects.sum().item()
            else:
                communicate(
                    tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)

    model.train()
    args.micro_batch_size = saved_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                        '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                            epoch, name, correct_ans, total_count, percent,
                            elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
def get_global_batch_size(args):
    return args.batch_size * mpu.get_data_parallel_world_size() * args.gas
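# Worked example for the formula above (values illustrative): with a
# per-GPU batch size of 4, a data-parallel world size of 8, and
# gradient-accumulation steps (args.gas) of 2, one optimizer step
# consumes 4 * 8 * 2 = 64 samples.
batch_size, dp_world_size, gas = 4, 8, 2
assert batch_size * dp_world_size * gas == 64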
def calculate_correct_answers(name, model, dataloader, epoch,
                              output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute that means
        # each "sample" from the dataset actually has multiple samples
        # that will collapse into the batch dimension (for example in
        # the RACE dataset that has several options), we need to
        # account for that when setting the micro batch size.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = \
        args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = \
        args.orig_global_batch_size // micro_batch_size_times_data_parallel

    def loss_func(output_predictions, labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Add output predictions.
        if output_predictions:
            assert False
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        args = get_args()
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        return output_tensor, partial(loss_func, output_predictions, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # For evaluation only mode we use drop_last = False to get all the
            # samples, which means we might not have a full batch, so we
            # adjust batch_size here to actual batch size of data
            actual_batch_size = len(batch['label'])
            # ... applying sample_multiplier if necessary
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = \
                actual_batch_size * sample_multiplier * num_micro_batches

            loss_dicts = forward_backward_func(correct_answers_forward_step,
                                               batch, model,
                                               optimizer=None,
                                               timers=None,
                                               forward_only=True)

            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                        '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                            epoch, name, correct_ans, total_count, percent,
                            elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build train, validation, and test data iterators."""
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train: {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test: {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = (args.iteration // args.eval_interval) * \
            args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval time').start()
    report_memory_flag = True

    data_parallel_size = mpu.get_data_parallel_world_size()
    global_batch_size = args.batch_size * data_parallel_size
    while iteration < args.train_iters and \
            (args.train_tokens is None or args.tokens < args.train_tokens):
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        iteration += 1
        if args.curriculum_learning:
            args.tokens += global_batch_size * args.curriculum_seqlen
        else:
            args.tokens += global_batch_size * args.seq_length

        # Logging.
        loss_scale = None
        if args.fp16:
            loss_scale = optimizer.cur_scale if args.deepspeed \
                else optimizer.loss_scale
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          model=model)

        # Autoresume
        if args.adlr_autoresume and \
                (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Checkpointing
        if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Evaluation
        # XXX temporarily disabled for ZeRO-3
        """
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)
        """

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
                         'iteration {}'.format(rank, time_str, iteration))
            sys.exit()

    return iteration
def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
    args = get_args()

    local_batch_size = query_tokens.shape[0]
    group, rank, world_size = get_group_world_size_rank()
    # recall we assert that model_parallel_size == 1
    global_batch_size = world_size * local_batch_size

    query_logits, context_logits = output_tensor

    if world_size > 1:
        input_ = torch.empty_like(context_logits).copy_(
            context_logits).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank].copy_(input_)
        torch.distributed.all_gather(tensor_list, input_, group=group)

        # Check if all-gather happens in order
        assert tensor_list[rank].sum().item() == \
            context_logits.sum().item()

        # Preserves the gradient
        tensor_list[rank] = context_logits
        all_context_logits = torch.cat(tensor_list, dim=0).contiguous()

        # Query tensors
        input_ = torch.empty_like(query_logits).copy_(
            query_logits).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank].copy_(input_)
        torch.distributed.all_gather(tensor_list, input_, group=group)

        # Check if all-gather happens in order
        assert tensor_list[rank].sum().item() == query_logits.sum().item()

        # Preserves the gradient
        tensor_list[rank] = query_logits
        all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
    else:
        all_query_logits = query_logits
        all_context_logits = context_logits

    retrieval_scores = torch.matmul(all_query_logits,
                                    torch.transpose(all_context_logits, 0, 1))
    # Scaling the retrieval scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    if args.train_with_neg:
        # if the world size is 3, local batch size is 4, and
        # local context size is 8, what we want is
        # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
        labels = []
        local_context_size = context_tokens.shape[0]
        for i in range(world_size):
            j = i * local_context_size
            labels.extend(list(range(j, j + local_batch_size)))
        labels = torch.LongTensor(labels).cuda()
        assert len(labels) == global_batch_size
    else:
        labels = torch.arange(global_batch_size).long().cuda()

    # Cross-entropy loss.
    softmax_scores = F.log_softmax(retrieval_scores, dim=1)

    loss = F.nll_loss(softmax_scores, labels, reduction='mean')

    max_score, max_idxs = torch.max(softmax_scores, 1)
    correct_predictions_count = (max_idxs == labels).sum().float()

    # Reduce loss for logging.
    reduced_loss = average_losses_across_data_parallel_group(
        [loss, correct_predictions_count])

    # Loss scaling for correct losses in Supervised Retrieval
    loss = loss * mpu.get_data_parallel_world_size()

    return loss, {'lm loss': reduced_loss[0],
                  'correct_prediction_count': reduced_loss[1]}
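# Standalone check reproducing the labels example from the comment above
# (world size 3, local batch size 4, local context size 8): each rank's
# positives sit at the start of its own block of gathered contexts.
world_size, local_batch_size, local_context_size = 3, 4, 8
labels = []
for i in range(world_size):
    j = i * local_context_size
    labels.extend(range(j, j + local_batch_size))
assert labels == [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]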
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
            args.micro_batch_size * \
            get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter)

        # Autoresume
        if args.adlr_autoresume and \
                (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
                args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime(
                    'exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime(
                'exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
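# Sanity-check sketch for the consumed_train_samples bookkeeping above
# (values illustrative): one iteration advances the counter by exactly
# one global batch, i.e. dp_size * micro_batch_size * num_microbatches.
dp_size, micro_batch_size, num_microbatches = 8, 2, 4
global_batch_size = dp_size * micro_batch_size * num_microbatches
assert global_batch_size == 64  # samples consumed per training step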