def build_pretraining_data_loader(dataset, consumed_samples):
    """Build dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler
    if args.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size(),
            data_sharding=args.data_sharding)
    else:
        raise Exception('{} dataloader type is not supported.'.format(
            args.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
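# A minimal, hypothetical sketch (not the Megatron implementation) of the idea behind
# MegatronPretrainingSampler used above: each data-parallel rank takes a contiguous
# slice of micro_batch_size indices out of every global batch, starting after
# consumed_samples. The class name and behavior here are illustrative only.
class ToyPretrainingSampler:
    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        # this rank's slice inside one global batch of micro_batch_size * data_parallel_size
        self.start = data_parallel_rank * micro_batch_size
        self.end = self.start + micro_batch_size
        self.global_batch = micro_batch_size * data_parallel_size

    def __iter__(self):
        batch = []
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.global_batch:
                # yield only the indices belonging to this data-parallel rank
                yield batch[self.start:self.end]
                batch = []

# e.g. rank 1 of 4 with micro_batch_size=2 sees indices [2, 3], then [10, 11], ...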
def _set_mips_index(self):
    """
    Create a Faiss Flat index with inner product as the metric
    to search against
    """
    try:
        import faiss
    except ImportError:
        raise Exception(
            "Error: Please install faiss to use FaissMIPSIndex")

    if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
        print("\n> Building index", flush=True)

    cpu_index = faiss.IndexFlatIP(self.embed_size)

    if self.use_gpu:
        # create resources and config for GpuIndex
        config = faiss.GpuMultipleClonerOptions()
        config.shard = True
        config.useFloat16 = True
        gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
        self.mips_index = faiss.IndexIDMap(gpu_index)
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Initialized index on GPU", flush=True)
    else:
        # CPU index supports IDs so wrap with IDMap
        self.mips_index = faiss.IndexIDMap(cpu_index)
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Initialized index on CPU", flush=True)

    # if we were constructed with a BlockData, then automatically load it
    # when the FAISS structure is built
    if self.embed_data is not None:
        self.add_embed_data(self.embed_data)
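# Minimal self-contained sketch of the CPU path above, assuming faiss and numpy are
# installed: an inner-product flat index wrapped in IndexIDMap so vectors can be added
# with user-chosen integer ids and those ids come back from search. Sizes are illustrative.
import numpy as np
import faiss

embed_size = 128
cpu_index = faiss.IndexFlatIP(embed_size)
mips_index = faiss.IndexIDMap(cpu_index)

embeds = np.random.rand(1000, embed_size).astype(np.float32)   # faiss expects float32
ids = np.arange(1000, dtype=np.int64)
mips_index.add_with_ids(embeds, ids)

queries = np.random.rand(4, embed_size).astype(np.float32)
scores, retrieved_ids = mips_index.search(queries, 5)          # top-5 inner products per query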
def load_from_file(self):
    """Populate members from instance saved to file"""

    if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
        print("\n> Unpickling BlockData", flush=True)
    state_dict = pickle.load(open(self.embedding_path, 'rb'))
    if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
        print(">> Finished unpickling BlockData\n", flush=True)

    self.embed_data = state_dict['embed_data']
def add_block_embed_data(self, all_block_data):
    """Add the embedding of each block to the underlying FAISS index"""

    # this assumes the embed_data is a dict : {int: np.array<float>}
    block_indices, block_embeds = zip(*all_block_data.embed_data.items())

    # the embeddings have to be entered in as float32 even though the math
    # internally is done with float16.
    block_embeds_arr = np.float32(np.array(block_embeds))
    block_indices_arr = np.array(block_indices)

    # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
    if self.use_gpu:
        for i, idx in enumerate(block_indices):
            self.id_map[i] = idx

    # we no longer need the embedding data since it's in the index now
    all_block_data.clear()

    if self.use_gpu:
        self.block_mips_index.add(block_embeds_arr)
    else:
        self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)

    if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
        print(">>> Finished adding block data to index", flush=True)
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    if args.deepspeed:
        save_ds_checkpoint(iteration, model, args)
    else:
        # Only rank zero of the data parallel writes to the disk.
        if isinstance(model, torchDDP):
            model = model.module
        if mpu.get_data_parallel_rank() == 0:

            # Arguments, iteration, and model.
            state_dict = {}
            state_dict['args'] = args
            state_dict['checkpoint_version'] = 2.0
            state_dict['iteration'] = iteration
            state_dict['model'] = model.state_dict_for_save_checkpoint()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    state_dict['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    state_dict['lr_scheduler'] = lr_scheduler.state_dict()

            # RNG states.
            if not args.no_save_rng:
                state_dict['random_rng_state'] = random.getstate()
                state_dict['np_rng_state'] = np.random.get_state()
                state_dict['torch_rng_state'] = torch.get_rng_state()
                state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
                state_dict['rng_tracker_states'] \
                    = mpu.get_cuda_rng_tracker().get_states()

            # Save.
            checkpoint_name = get_checkpoint_name(args.save, iteration)
            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
                  format(torch.distributed.get_rank(), iteration,
                         checkpoint_name))
            ensure_directory_exists(checkpoint_name)
            torch.save(state_dict, checkpoint_name)
            print(' successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))
    # Wait so everyone is done (necessary)
    torch.distributed.barrier()

    if args.keep_last_n_checkpoints is not None:
        delete_old_checkpoints(args.save, args.keep_last_n_checkpoints)

    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    assert False, 'DistributedBatchSampler deprecated, change the implementation'
    from megatron.data.samplers import DistributedBatchSampler
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = torchDDP(model, device_ids=[i], output_device=i,
                         process_group=mpu.get_data_parallel_group())
        return model
    if args.DDP_impl == 'local':
        model = LocalDDP(model)
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
def init_state_dict_from_bert(self):
    """Initialize the state from a pretrained BERT model
    on iteration zero of ICT pretraining"""
    args = get_args()
    tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
    if not os.path.isfile(tracker_filename):
        raise FileNotFoundError("Could not find BERT load for ICT")
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
        assert iteration > 0

    checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except BaseException:
        raise ValueError("Could not load checkpoint")

    # load the LM state dict into each model
    model_dict = state_dict['model']['language_model']
    self.query_model.language_model.load_state_dict(model_dict)
    self.block_model.language_model.load_state_dict(model_dict)

    # give each model the same ict_head to begin with as well
    query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[
        self._query_key]['ict_head']
    self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if torch.distributed.is_initialized():
            mpu.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = mpu.get_model_parallel_group()
            app_state.data_parallel_group = mpu.get_data_parallel_group()
            app_state.model_parallel_rank = mpu.get_tensor_model_parallel_rank()
            app_state.data_parallel_rank = mpu.get_data_parallel_rank()
            app_state.data_parallel_size = mpu.get_data_parallel_world_size()
            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')

            # TODO: get random seed from PTL
            # os.environ.get returns a string when the variable is set, so cast to int
            seed = int(os.environ.get("PL_GLOBAL_SEED", 1234))
            # random seed must be set for megatron model parallel init
            _set_random_seed(seed)
def check_forward_pass(neox_args, model, checkpoint_logits, inference):
    # do forward pass with loaded checkpoint
    logits = do_forward_pass(neox_args=neox_args, model=model, inference=inference)

    # check
    if logits is not None and checkpoint_logits is not None:
        # logits can be None for non-final pipeline stages
        if not (logits == checkpoint_logits).all().item():
            if mpu.get_data_parallel_rank() == 0:
                print(" > WARNING: validate_checkpoint_forward() forward after load of "
                      "checkpoint does not yield exactly the same result")
        assert torch.isclose(logits, checkpoint_logits).all().item(), \
            "validate_checkpoint_forward() forward after load of checkpoint " \
            "does not yield a close result"
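# Why the check above warns on exact equality but only asserts torch.isclose: a small
# illustration, not part of the original code. Reduced-precision arithmetic can perturb
# logits slightly even when the checkpoint loaded correctly, so exact comparison may
# fail while a tolerance-based comparison still passes.
import torch

a = torch.tensor([1.0001, 2.0, 3.0])
b = torch.tensor([1.0002, 2.0, 3.0])
print((a == b).all().item())                          # False: exact equality is too strict
print(torch.isclose(a, b, atol=1e-3).all().item())    # True: equal within tolerance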
def load_ict_checkpoint(model, only_query_model=False, only_block_model=False,
                        from_realm_chkpt=False):
    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
    args = get_args()

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
        print(" loading ICT state dict from REALM", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    model.load_state_dict(ict_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
def load_biencoder_checkpoint(model, only_query_model=False,
                              only_context_model=False, custom_load_path=None):
    """
    selectively load retrieval models for indexing/retrieving
    from saved checkpoints
    """

    args = get_args()

    model = utils.unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ret_state_dict = state_dict['model']

    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | reserved: {}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    string += ' | max reserved: {}'.format(
        torch.cuda.max_memory_reserved() / mega_bytes)
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
              flush=True)
def _set_block_index(self):
    """Create a Faiss Flat index with inner product as the metric to search against"""
    try:
        import faiss
    except ImportError:
        raise Exception(
            "Error: Please install faiss to use FaissMIPSIndex")

    if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
        print("\n> Building index", flush=True)
    self.block_mips_index = faiss.index_factory(self.embed_size,
                                                'Flat',
                                                faiss.METRIC_INNER_PRODUCT)

    if self.use_gpu:
        # create resources and config for GpuIndex
        res = faiss.StandardGpuResources()
        config = faiss.GpuIndexFlatConfig()
        config.device = torch.cuda.current_device()
        config.useFloat16 = True

        self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Initialized index on GPU {}".format(
                self.block_mips_index.getDevice()), flush=True)
    else:
        # CPU index supports IDs so wrap with IDMap
        self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Initialized index on CPU", flush=True)

    # if we were constructed with a BlockData, then automatically load it
    # when the FAISS structure is built
    if self.block_data is not None:
        self.add_block_embed_data(self.block_data)
def get_total_params(model):
    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        params = sum([p.nelement() for p in model.parameters()])
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(), params), flush=True)
    else:
        params = 0

    total_n_parameters = torch.tensor([params]).cuda(torch.cuda.current_device())
    torch.distributed.all_reduce(total_n_parameters)
    total_n_parameters = total_n_parameters.item()
    return total_n_parameters
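# Hedged sketch (not from the source) of how all_reduce aggregates the per-rank counts
# above: every rank contributes a tensor and receives the element-wise sum. A single
# gloo process group (world_size=1) is used here only so the example runs standalone.
import torch
import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
counts = torch.tensor([12345])      # this rank's parameter count
dist.all_reduce(counts)             # sums the tensor across all ranks in the group
print(counts.item())                # with world_size=1 the sum equals the local count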
def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError(
            'Seed ({}) should be a positive integer.'.format(seed_))
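# Small illustration (not from the source) of the seed offsets used above: pipeline
# stages are spread 100 apart and, optionally, data-parallel ranks 10 apart, so ranks
# that need independent randomness never share a seed.
base_seed = 1234
for pp_rank in range(2):            # pipeline-model-parallel ranks
    for dp_rank in range(3):        # data-parallel ranks
        seed = base_seed + 100 * pp_rank + 10 * dp_rank
        print(f"pp={pp_rank} dp={dp_rank} -> seed {seed}")
# pp=0 dp=0 -> 1234, pp=0 dp=1 -> 1244, ..., pp=1 dp=2 -> 1354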
def __init__(self):
    args = get_args()
    self.model = None
    self.dataloader = None
    self.block_data = None

    # need to know whether we're using a REALM checkpoint (args.load)
    # or ICT checkpoint
    assert not (args.load and args.ict_load)
    self.using_realm_chkpt = args.ict_load is None

    self.log_interval = args.indexer_log_interval
    self.batch_size = args.indexer_batch_size

    self.load_attributes()
    self.is_main_builder = mpu.get_data_parallel_rank() == 0
    self.num_total_builders = mpu.get_data_parallel_world_size()
    self.iteration = self.total_processed = 0
def add_embed_data(self, all_embed_data):
    """Add the embedding of each block to the underlying FAISS index"""

    # this assumes the embed_data is a dict : {int: np.array<float>}
    block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

    # the embeddings have to be entered in as float32 even though the math
    # internally is done with float16.
    embeds_arr = np.float32(np.array(block_embeds))
    indices_arr = np.array(block_indices)

    # we no longer need the embedding data since it's in the index now
    all_embed_data.clear()

    self.mips_index.add_with_ids(embeds_arr, indices_arr)

    if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
        print(">>> Finished adding block data to index", flush=True)
def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for param in model.parameters():
        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16Module(model)

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = torchDDP(model, device_ids=[i], output_device=i,
                         process_group=mpu.get_data_parallel_group())
        return model
    if args.DDP_impl == 'local':
        model = LocalDDP(model)
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
def build_data_loader(dataset, batch_size, num_workers, drop_last):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              sampler=sampler,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              drop_last=drop_last,
                                              pin_memory=True)

    return data_loader
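# Hedged usage sketch (not in the source) of how a DataLoader built this way is consumed.
# num_replicas and rank are passed explicitly so the example runs without
# torch.distributed being initialized; set_epoch only matters when the sampler shuffles.
import torch

dataset = torch.utils.data.TensorDataset(torch.arange(16).float())
sampler = torch.utils.data.distributed.DistributedSampler(
    dataset, num_replicas=2, rank=0, shuffle=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler,
                                     num_workers=0, drop_last=False)
for epoch in range(2):
    sampler.set_epoch(epoch)   # keeps shuffling (if enabled) different each epoch
    for (batch,) in loader:
        pass                   # rank 0 sees half of the 16 samples each epoch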
def __init__(self):
    args = get_args()
    self.model = None
    self.dataloader = None
    self.evidence_embedder_obj = None
    self.biencoder_shared_query_context_model = \
        args.biencoder_shared_query_context_model

    # need to know whether we're using a REALM checkpoint (args.load)
    # or ICT checkpoint
    assert not (args.load and args.ict_load)

    self.log_interval = args.indexer_log_interval
    self.batch_size = args.indexer_batch_size

    self.load_attributes()
    self.is_main_builder = mpu.get_data_parallel_rank() == 0
    self.num_total_builders = mpu.get_data_parallel_world_size()
    self.iteration = self.total_processed = 0
def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
    self.cache_hook = base.CacheHook(None)
    self.model = model
    self.neox_args = neox_args
    self.tokenizer = neox_args.tokenizer
    self._device = torch.device(f"cuda:{neox_args.local_rank}")
    self._eot_token_id = neox_args.tokenizer.eod_id
    self._max_length = neox_args.max_position_embeddings // 2
    self._max_gen_toks = 128
    self._vocab_size = neox_args.padded_vocab_size

    # parallelism args:
    self.is_main = neox_args.rank == 0
    self.is_local_main = neox_args.local_rank == 0
    self.is_model_parallel = neox_args.model_parallel_size > 1
    self.is_pipe_parallel = self.model.is_pipe_parallel
    self.is_data_parallel = self.model.is_data_parallel
    self.is_last_stage = (
        True if not self.is_pipe_parallel else model.is_last_stage()
    )  # only the last stage of the pipeline model will receive the logits
    self.dp_world_size = mpu.get_data_parallel_world_size()
    self.dp_rank = mpu.get_data_parallel_rank()
    self.dp_group = mpu.get_data_parallel_group()
    self.is_mp_rank_0 = mpu.get_model_parallel_rank() == 0

    self._batch_size = batch_size or (
        neox_args.batch_size * self.dp_world_size
    )  # default batch size to bs per gpu * dp size

    # some utility functions:
    # we need to patch tokenizer methods, because lm_eval uses them internally:
    self.tokenizer.encode = self.tokenizer.tokenize
    self.tokenizer.decode = self.tokenizer.detokenize
    self._forward_step_fn = partial(
        forward_step_fn, neox_args=neox_args, timers=None, return_logits=True
    )
    self.generate = partial(
        generate_samples_from_prompt,
        neox_args=neox_args,
        model=model,
        maximum_tokens=self._max_gen_toks,
        temperature=0.0,
    )
def build_pretraining_data_loader(dataset, consumed_samples):
    """Build dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        micro_batch_size=args.micro_batch_size,
        data_parallel_rank=mpu.get_data_parallel_rank(),
        data_parallel_size=mpu.get_data_parallel_world_size())

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
def model_provider():
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    see_memory_usage("Before Building Model", force=True)
    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                             remote_device=get_args().remote_device,
                             deepspeed_config=get_args().deepspeed_config,
                             enabled=get_args().zero_stage == 3):
        model = GPT2Model(num_tokentypes=0, parallel_output=True)
    see_memory_usage("After Building Model", force=True)
    if mpu.get_data_parallel_rank() == 0:
        billion_params = get_parameters_in_billions(model)
        print(f' > number of parameters on model parallel rank '
              f'{mpu.get_model_parallel_rank()}: {round(billion_params, 3)} Billion',
              flush=True)
    return model
def make_data_loader(dataset, neox_args):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = neox_args.batch_size * world_size
    num_workers = neox_args.num_workers

    # Use a simple sampler with distributed batch sampler.
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
def get_one_epoch_dataloader(dataset, batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
def get_fmoe_checkpoint_name(checkpoints_path, iteration,
                             release=False, data_parallel_rank=-1):
    """A unified checkpoint name, allowing specifying a data parallel rank"""
    from megatron import mpu
    from megatron.checkpointing import get_checkpoint_name

    if data_parallel_rank == -1:
        data_parallel_rank = mpu.get_data_parallel_rank()
    if data_parallel_rank == 0:
        return get_checkpoint_name(checkpoints_path, iteration, release)

    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
            checkpoints_path,
            directory,
            "mp_rank_{:02d}_dp_rank_{:04d}".format(
                mpu.get_tensor_model_parallel_rank(), data_parallel_rank),
            "model_optim_rng.pt",
        )
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}_{:03d}_dp_rank_{:04d}".format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            data_parallel_rank,
        ),
        "model_optim_rng.pt",
    )
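# Illustration (paths and rank values hypothetical) of the directory layout produced by
# the format strings above for iteration 5000, tensor-MP rank 1, pipeline-MP rank 2,
# data-parallel rank 3.
import os

print(os.path.join("checkpoints", "iter_{:07d}".format(5000),
                   "mp_rank_{:02d}_dp_rank_{:04d}".format(1, 3),
                   "model_optim_rng.pt"))
# checkpoints/iter_0005000/mp_rank_01_dp_rank_0003/model_optim_rng.pt
print(os.path.join("checkpoints", "iter_{:07d}".format(5000),
                   "mp_rank_{:02d}_{:03d}_dp_rank_{:04d}".format(1, 2, 3),
                   "model_optim_rng.pt"))
# checkpoints/iter_0005000/mp_rank_01_002_dp_rank_0003/model_optim_rng.pt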
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    num_workers = args.num_workers

    # Use megatron's sampler with consumed samples set to 0 as
    # this is only for evaluation and we don't intend to resume halfway.
    # Also, set drop_last to False as we don't want to drop the last batch.
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        micro_batch_size=micro_batch_size,
        data_parallel_rank=mpu.get_data_parallel_rank(),
        data_parallel_size=mpu.get_data_parallel_world_size(),
        drop_last=False)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
def load_checkpoint(model, optimizer, lr_scheduler):
    """Load a model checkpoint and return the iteration."""
    args = get_args()

    if isinstance(model, torchDDP):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(args.load)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0('    will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        # For backward compatibility.
        print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
    except BaseException:
        print_rank_0('could not load the checkpoint')
        sys.exit()

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:
                # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()

    # Check arguments.
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    model.load_state_dict(state_dict['model'])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = \
                    model[i].state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
            if optimizer is not None:
                state_dict['optimizer'] = optimizer.state_dict()
            if lr_scheduler is not None:
                state_dict['lr_scheduler'] = lr_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict['random_rng_state'] = random.getstate()
            state_dict['np_rng_state'] = np.random.get_state()
            state_dict['torch_rng_state'] = torch.get_rng_state()
            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
            state_dict['rng_tracker_states'] \
                = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(
        ' successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save))

    # And update the latest iteration
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
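# A stripped-down, self-contained sketch (not the Megatron implementation) of the same
# checkpointing pattern: save a state dict under a per-iteration directory, then record
# the latest iteration in a tracker file so a later load knows where to resume from.
# The tracker filename and directory names here are illustrative.
import os
import tempfile
import torch

save_dir = tempfile.mkdtemp()
iteration = 1000
model = torch.nn.Linear(4, 4)

checkpoint_name = os.path.join(save_dir, "iter_{:07d}".format(iteration), "model.pt")
os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)
torch.save({"iteration": iteration, "model": model.state_dict()}, checkpoint_name)

with open(os.path.join(save_dir, "latest_checkpointed_iteration.txt"), "w") as f:
    f.write(str(iteration))

# resuming: read the tracker, then load that iteration's state dict
with open(os.path.join(save_dir, "latest_checkpointed_iteration.txt")) as f:
    latest = int(f.read().strip())
state = torch.load(os.path.join(save_dir, "iter_{:07d}".format(latest), "model.pt"),
                   map_location="cpu")
model.load_state_dict(state["model"])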