def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
    swap_info = self._get_param_swap_info(parameter)
    if swap_info is None:
        return

    assert len(swap_info.tensors) <= len(dest_buffers)

    swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors)
    swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

    READ_TIMER = 'swap_submit_read_param'
    WAIT_TIMER = 'swap_wait_read_param'

    self._start_timer(READ_TIMER)
    swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
    self._stop_timer(READ_TIMER)

    swap_bytes = sum(
        [buffer.numel() * buffer.element_size() for buffer in swap_buffers])

    self._start_timer(WAIT_TIMER)
    aio_handle.wait()
    self._stop_timer(WAIT_TIMER)

    compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
    compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
    for t, buffer in zip(swap_info.tensors, compute_buffers):
        t.data = buffer.data

    self._log_timers([READ_TIMER, WAIT_TIMER])

    if DEBUG_MODE and torch.distributed.get_rank() == 0:
        logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')
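# --- Illustrative sketch (not part of the swapper) ---
# Note how swap_lengths above uses _io_aligned_numel while compute_lengths uses
# the exact numel: async disk I/O must cover whole aligned blocks, so reads are
# padded and then narrowed back for compute. A minimal stand-alone version of
# the padding, assuming alignment is expressed in elements (constant made up):
AIO_ALIGNED_ELEMENTS = 1024  # assumed I/O alignment granularity


def io_aligned_numel(numel, alignment=AIO_ALIGNED_ELEMENTS):
    """Round numel up to the nearest multiple of alignment."""
    remainder = numel % alignment
    return numel if remainder == 0 else numel + alignment - remainder


assert io_aligned_numel(1000) == 1024  # padded up to one block
assert io_aligned_numel(2048) == 2048  # already aligned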
def _report_statistics(self, message):
    if torch.distributed.get_rank() == 0:
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        swapped_GB = (self.num_elements_swapped * element_size) / (1024**3)
        logger.info(
            f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB')
def _initialize_parameters(self, parameters, src_tensors, aio_handle):
    assert len(parameters) == len(src_tensors)

    swap_paths = self._get_swap_paths(
        parameters=parameters, num_elems=[src.numel() for src in src_tensors])

    SWAP_INIT_TIMER = "swap_init_write"
    self._start_timer(SWAP_INIT_TIMER)

    pinned_buffers = self.swap_buffer_manager.allocate_all(
        num_elems=self.largest_numel, dtype=self.dtype)
    assert pinned_buffers is not None

    self._swap_out_unpinned_tensors(aio_handle=aio_handle,
                                    unpinned_tensors=src_tensors,
                                    dest_paths=swap_paths,
                                    pinned_buffers=pinned_buffers)

    if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
        for i, tensor in enumerate(src_tensors):
            logger.info(
                f'copy_in_fp16_param: fp32_id = {id(parameters[i])} index = {i}, '
                f'swap_num_elem = {src_tensors[i].numel()}')

    self.swap_buffer_manager.free(pinned_buffers)

    self._stop_timer(SWAP_INIT_TIMER)
    self._log_timers([SWAP_INIT_TIMER])
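# --- Illustrative sketch (assumed behavior of _swap_out_unpinned_tensors) ---
# Unpinned tensors cannot be handed to async I/O directly, so each one is first
# copied into a page-locked (pinned) staging buffer and written from there.
# write_fn is a placeholder standing in for the real aio_handle submit + wait.
import torch


def swap_out_via_pinned_staging(unpinned_tensors, dest_paths, pinned_buffer, write_fn):
    for tensor, path in zip(unpinned_tensors, dest_paths):
        staged = pinned_buffer.narrow(0, 0, tensor.numel())
        staged.copy_(tensor.flatten())  # host copy into pinned memory
        write_fn(staged, path)          # async write + wait in the real code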
def main():
    start = time.time()
    args = construct_arguments()
    model, optimizer = prepare_model_optimizer(args)
    start_epoch = 0
    if None not in [args.load_training_checkpoint, args.load_checkpoint_id]:
        start_epoch = load_checkpoint(args, model)
    run(args, model, optimizer, start_epoch)
    elapsed = time.time() - start
    logger = args.logger
    logger.info(f"Elapsed time: {elapsed} seconds")
def _initialize_from_swapped_fp16_params(self, aio_handle, fp16_partitions_info,
                                         fp16_num_elems, fp16_pinned_buffers,
                                         fp32_parameters):
    assert len(fp32_parameters) == len(fp16_partitions_info)
    assert len(fp32_parameters) == len(fp16_num_elems)
    assert all([buffer.is_pinned() for buffer in fp16_pinned_buffers])

    fp32_swap_paths = self._get_swap_paths(parameters=fp32_parameters,
                                           num_elems=fp16_num_elems)

    fp32_pinned_buffers = self.swap_buffer_manager.allocate_all(
        num_elems=self.largest_numel, dtype=self.dtype)

    fp16_buffer_numel = [buf.numel() for buf in fp16_pinned_buffers]
    assert all([numel >= self.largest_numel for numel in fp16_buffer_numel]), \
        f"numel of fp16 buffers {fp16_buffer_numel} is too small for initializing fp32 params {self.largest_numel}"

    fp32_swap_buffers = SwapBufferPool(fp32_pinned_buffers)
    fp16_swap_buffers = SwapBufferPool(fp16_pinned_buffers)

    curr_index = 0
    while curr_index < len(fp32_parameters):
        fp16_pinned_tensors = self._swap_in_fp16_params(
            aio_handle=aio_handle,
            fp16_num_elems=fp16_num_elems[curr_index:],
            fp16_partitions_info=fp16_partitions_info[curr_index:],
            fp16_swap_buffers=fp16_swap_buffers)

        if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
            for i, tensor in enumerate(fp16_pinned_tensors):
                true_index = curr_index + i
                logger.info(
                    f'swap_in_fp16_param: fp32_id = {id(fp32_parameters[true_index])} '
                    f'index = {true_index} orig_num_elem = {fp16_num_elems[true_index]}, '
                    f'swap_num_elem = {fp16_pinned_tensors[i].numel()}')

        swap_out_count = self._swap_out_fp16_params(
            aio_handle=aio_handle,
            fp32_swap_paths=fp32_swap_paths[curr_index:],
            fp32_swap_buffers=fp32_swap_buffers,
            fp16_pinned_tensors=fp16_pinned_tensors)

        assert swap_out_count == len(fp16_pinned_tensors), \
            f"{swap_out_count} does not match {len(fp16_pinned_tensors)}"

        fp16_swap_buffers.reset()
        fp32_swap_buffers.reset()
        curr_index += swap_out_count

    self.swap_buffer_manager.free(fp32_pinned_buffers)
def _retrieve_unswapped_grad_partitions(self, swap_info, dest_buffer):
    UNSWAPPED_READ_GRADIENTS = 'unswapped_read_gradients'
    self._start_timer(UNSWAPPED_READ_GRADIENTS)
    tensor_count = len(swap_info.unswapped_gradients)
    num_elem_count = swap_info.read_unswapped_gradients(dest_buffer)
    self._stop_timer(UNSWAPPED_READ_GRADIENTS)
    self._log_timers([UNSWAPPED_READ_GRADIENTS])

    # It should be safe to discard unswapped gradient partitions
    swap_info.release_unswapped_gradients()

    if SWAPPER_DEBUG_MODE:
        logger.info(
            f'optimizer_retrieve_unswapped_gradients: param={swap_info.param_id} '
            f'tensor_count={tensor_count} elem_count={num_elem_count}')
def swap_out_optimizer_state(self, parameter, async_swap=False):
    swap_info = self._get_param_swap_info(parameter=parameter)
    if swap_info is None:
        return

    self._start_timer(SWAP_OUT_PARAM_TIMER)
    pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = \
        self._separate_pinned_tensors(swap_info)
    swap_bytes = sum([
        self._io_aligned_numel(t.numel()) * t.element_size()
        for t in swap_info.tensors
    ])

    WRITE_TIMER = 'swap_submit_write'
    self._start_timer(WRITE_TIMER)

    swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
    assert self.aio_handle.wait() == len(pinned_tensors)
    for t in pinned_tensors:
        t.data = torch.Tensor()

    if len(unpinned_tensors) > 0:
        pinned_buffers = self.swap_buffer_manager.allocate_all(
            num_elems=self.largest_numel, dtype=self.dtype)
        self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
                                        unpinned_tensors=unpinned_tensors,
                                        dest_paths=unpinned_paths,
                                        pinned_buffers=pinned_buffers)
        self.allocated_swap_buffers += pinned_buffers

        for t in unpinned_tensors:
            t.data = torch.Tensor()
    self._stop_timer(WRITE_TIMER)

    self.swap_buffer_manager.free(self.allocated_swap_buffers)
    self.allocated_swap_buffers = []

    self._stop_timer(SWAP_OUT_PARAM_TIMER)
    self.timer_names.add(SWAP_OUT_PARAM_TIMER)
    self._log_timers([WRITE_TIMER])

    if DEBUG_MODE and torch.distributed.get_rank() == 0:
        logger.info(f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB')
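# --- Illustrative sketch (assumed behavior of _separate_pinned_tensors) ---
# Tensors already in page-locked memory can be written to disk directly, while
# the rest must go through pinned staging buffers, so the swap-out path splits
# the tensors (and their swap paths) by torch's is_pinned() flag.
def separate_pinned(tensors, paths):
    pinned, unpinned = ([], []), ([], [])
    for tensor, path in zip(tensors, paths):
        dest = pinned if tensor.is_pinned() else unpinned
        dest[0].append(tensor)
        dest[1].append(path)
    return pinned[0], pinned[1], unpinned[0], unpinned[1]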
def load_checkpoint(args, model):
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    config = args.config
    logger = args.logger

    logger.info(
        f"Restoring previous training checkpoint from PATH={args.load_training_checkpoint}, "
        f"CKPT_ID={args.load_checkpoint_id}")
    start_epoch, global_step, global_data_samples = load_training_checkpoint(
        args=args,
        model=model,
        PATH=args.load_training_checkpoint,
        ckpt_id=args.load_checkpoint_id)
    logger.info(
        f"The model is loaded from the last checkpoint at epoch {start_epoch}, "
        f"global step {global_step}, and global data samples {global_data_samples}")

    if args.rewarmup:
        logger.info(
            f"Rewarmup learning rate with last_global_step_from_restore = {global_step}")
        last_global_step_from_restore = global_step

    lr_this_step = config["training"]["learning_rate"] * warmup_linear_decay_exp(
        global_step, config["training"]["decay_rate"],
        config["training"]["decay_step"],
        config["training"]["total_training_steps"],
        config["training"]["warmup_proportion"])
    logger.info(f"Restart training with lr = {lr_this_step}")

    # Run validation for checkpoint before training
    if not args.finetune and args.max_seq_length == 512:
        logger.info(
            f"Validation loss of checkpoint {start_epoch} before pretraining")
        index = start_epoch - 1 if start_epoch > 0 else start_epoch
        pretrain_validation(args, index, model)

    return start_epoch
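# --- Illustrative sketch (assumed shape of warmup_linear_decay_exp) ---
# The restart path rescales the base learning rate by a schedule that warms up
# linearly for warmup_proportion of training and then decays exponentially.
# This stand-alone version only shows the assumed shape; the real function
# lives in the training utilities and may differ in detail.
def warmup_linear_decay_exp_sketch(step, decay_rate, decay_steps, total_steps,
                                   warmup_proportion):
    warmup_steps = warmup_proportion * total_steps
    if step < warmup_steps:
        return step / warmup_steps  # linear warmup from 0 to 1
    return decay_rate**((step - warmup_steps) / decay_steps)  # exponential decay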
def run(args, model, optimizer, start_epoch):
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    config = args.config
    logger = args.logger

    if args.use_nvidia_dataset:
        pretrain_dataset_provider = NvidiaBertDatasetProvider(args)
    else:
        pretrain_dataset_provider = BingBertDatasetProvider(args)

    for index in range(start_epoch, config["training"]["num_epochs"]):
        logger.info(f"Training Epoch: {index + 1}")
        pre = time.time()
        train(args, index, model, optimizer, pretrain_dataset_provider)

        # Save checkpoints according to the "--ckpt_to_save" option,
        # e.g. "--ckpt_to_save 160 161" saves epochs 160 and 161.
        if args.ckpt_to_save is None or (index + 1) in args.ckpt_to_save:
            logger.info(f"Saving a checkpoint of the model for epoch: {index + 1}")
            checkpoint_model(PATH=args.saved_model_path,
                             ckpt_id='epoch{}_step{}'.format(index + 1, global_step),
                             model=model,
                             epoch=index + 1,
                             last_global_step=global_step,
                             last_global_data_samples=global_data_samples)

        post = time.time()
        logger.info(f"Time for shard {index + 1}: {post - pre} seconds")

        current_global_step = global_step - last_global_step_from_restore
        if is_time_to_exit(args=args, global_steps=current_global_step):
            print(f'Warning: Early training termination due to max steps limit, '
                  f'epoch={index + 1}, global_step={current_global_step}')
            break
def __init__(self, *super_args, **super_kwargs):
    super().__init__(*super_args, **super_kwargs)
    assert isinstance(self.module, PipelineModule), "model must be a PipelineModule"

    # We schedule the all-reduces, so disable it in super().backward()
    self.enable_backward_allreduce = False
    assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
        " with pipeline parallelism."

    # pipeline step for logging
    self.log_batch_step_id = -1

    self.micro_batch_size = self.train_micro_batch_size_per_gpu()
    self.micro_batches = self.gradient_accumulation_steps()

    # Set Grid and Communication Groups
    self.grid = self.module._grid
    if self.grid.get_global_rank() == 0:
        logger.info(f'CONFIG: micro_batches={self.micro_batches} '
                    f'micro_batch_size={self.micro_batch_size}')

    self.global_rank = self.grid.get_global_rank()

    assert self.dp_world_size == self.grid.data_parallel_size
    assert self.train_batch_size() == \
        self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size

    # Set Stage Info
    self.num_stages = self.grid.pipe_parallel_size
    self.stage_id = self.grid.get_stage_id()
    self.prev_stage = self.stage_id - 1
    self.next_stage = self.stage_id + 1

    self.data_iterator = None
    self.batch_fn = None

    self._force_grad_boundary = False

    self.batch_timer = ThroughputTimer(batch_size=self.micro_batch_size *
                                       self.micro_batches,
                                       num_workers=self.dp_world_size,
                                       logging_fn=self.tput_log,
                                       monitor_memory=False,
                                       steps_per_output=self.steps_per_print())

    # PipelineEngine needs to handle data loading specially due to only the first
    # and last stages loading inputs/labels. We construct a sampler that uses
    # the data parallel rank and world size.
    if self.training_data:
        self._build_data_iter(self.training_data)

    self.is_pipe_parallel = self.grid.pipe_parallel_size > 1
    self.is_data_parallel = self.grid.data_parallel_size > 1
    self.is_model_parallel = self.grid.model_parallel_size > 1

    # Partition input/output buffers
    self.is_pipe_partitioned = self.is_model_parallel
    self.is_grad_partitioned = False

    model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
    num_params = sum([p.numel() for p in model_parameters])
    unique_params = num_params
    # Subtract tied parameters if we don't own them
    if self.module.tied_comms:
        tied_params = 0
        for key, d in self.module.tied_comms.items():
            if self.global_rank != min(d['ranks']):
                tied_params += sum(p.numel() for p in d['module'].parameters())
        unique_params -= tied_params
    params_tensor = torch.LongTensor(data=[num_params, unique_params]).to(self.device)
    dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())
    params_tensor = params_tensor.tolist()
    total_params = params_tensor[0]
    unique_params = params_tensor[1]
    if self.grid.data_parallel_id == 0:
        logger.info(f'RANK={self.global_rank} '
                    f'STAGE={self.stage_id} '
                    f'LAYERS={self.module._local_stop - self.module._local_start} '
                    f'[{self.module._local_start}, {self.module._local_stop}) '
                    f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '
                    f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '
                    f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')

    # Initialize peer-to-peer communication and allreduce groups
    if self.is_pipe_parallel:
        p2p.init_process_groups(self.grid)

    # Pipeline buffers
    self.num_pipe_buffers = 0
    self.pipe_buffers = {
        'inputs': [],  # batch input and received activations
        'labels': [],  # labels from batch input
        'outputs': [],  # activations
        'output_tensors': [],  # tensor object to preserve backward graph
    }
    self.pipe_recv_buf = None
    self.grad_layer = None
    self.meta_buffer = None

    self.first_output_send = True
    self.first_gradient_send = True

    # stores the loss for the current micro batch being processed
    self.loss = torch.tensor(0.0).to(self.device)

    # stores the loss for the entire batch
    self.total_loss = None
    self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
    self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)

    if self._config.pipeline['activation_checkpoint_interval'] > 0:
        self.module.activation_checkpoint_interval = self._config.pipeline[
            'activation_checkpoint_interval']

    if self.is_last_stage():
        self.loss_model = self.module.loss_fn

    # Initialize pipeline communicators. Just send a 0.
    if is_even(self.stage_id):
        if not self.is_last_stage():
            p2p.send(self.loss, self.next_stage)
        if not self.is_first_stage():
            p2p.recv(self.loss, self.prev_stage)
    else:
        if not self.is_first_stage():
            p2p.recv(self.loss, self.prev_stage)
        if not self.is_last_stage():
            p2p.send(self.loss, self.next_stage)

    # XXX look into timer reporting timing
    # Initialize some timers because of early weirdness.
    if self.wall_clock_breakdown():
        self.timers('forward_microstep').start()
        self.timers('forward_microstep').stop()
        self.timers('backward_microstep').start()
        self.timers('backward_microstep').stop()
        self.timers('backward_inner_microstep').start()
        self.timers('backward_inner_microstep').stop()
        self.timers('backward_allreduce_microstep').start()
        self.timers('backward_allreduce_microstep').stop()
        self.timers('backward_allreduce').start()
        self.timers('backward_allreduce').stop()
        self.timers('step_microstep').start()
        self.timers('step_microstep').stop()
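# --- Illustrative sketch of the even/odd handshake above ---
# Even stages send before receiving and odd stages receive before sending, so
# two neighboring stages never both block on a send to each other. The names
# below are placeholders for the p2p calls used in __init__.
def pipeline_handshake(stage_id, is_first, is_last, send, recv):
    if stage_id % 2 == 0:
        if not is_last:
            send()  # even stage: send to next_stage first
        if not is_first:
            recv()  # then receive from prev_stage
    else:
        if not is_first:
            recv()  # odd stage: receive from prev_stage first
        if not is_last:
            send()  # then send to next_stage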
def train(args, index, model, optimizer, pretrain_dataset_provider, finetune=False):
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    dataset_iterator, total_length = pretrain_dataset_provider.get_shard(index)
    current_data_sample_count = global_data_samples

    config = args.config
    logger = args.logger
    logger.info(
        f'worker-{dist.get_rank()}: begin epoch {index + 1} '
        f'current_sample_count {current_data_sample_count} '
        f'shard_length {total_length} global_data_samples {global_data_samples}')

    pretrain_dataset_provider.prefetch_shard(index + 1)

    model.train()

    epoch_step = 0
    rounds = 20
    all_step_time = 0.0

    for _, batch_index in enumerate(tqdm(dataset_iterator, smoothing=1)):
        try:
            step_start = time.time()
            batch = pretrain_dataset_provider.get_batch(batch_index)
            batch = tuple(t.to(args.device) for t in batch)  # Move to GPU

            # Calculate forward pass
            loss = model.network(batch)
            unscaled_loss = loss.item()
            current_data_sample_count += (args.train_micro_batch_size_per_gpu *
                                          dist.get_world_size())

            # Prefetch training data
            pretrain_dataset_provider.prefetch_batch()

            model.network.backward(loss)
            loss = None

            if model.network.is_gradient_accumulation_boundary():
                if args.fp16:
                    # Modify the learning rate with the special warmup BERT uses.
                    lr_this_step = update_learning_rate(args, config, global_step,
                                                        optimizer)
                else:
                    # BertAdam updates the LR internally; read the current value
                    # for reporting only.
                    lr_this_step = optimizer.param_groups[0]['lr']

                report_step_metrics(args, lr_this_step, unscaled_loss, global_step,
                                    current_data_sample_count)

                model.network.step()

                report_lamb_coefficients(args, optimizer)
                global_step += 1
                epoch_step += 1
            else:
                # Call DeepSpeed engine step on micro steps
                model.network.step()
        except StopIteration:
            continue

        current_global_step = global_step - last_global_step_from_restore
        if is_time_to_exit(args=args,
                           epoch_steps=epoch_step,
                           global_steps=current_global_step):
            print(f'Warning: Early epoch termination due to max steps limit, '
                  f'epoch_step={epoch_step}, global_step={current_global_step}, '
                  f'epoch={index + 1}')
            break

        step_time = time.time() - step_start
        all_step_time += step_time
        if (global_step % rounds == 0 and global_step != 0
                and model.network.is_gradient_accumulation_boundary()
                and dist.get_rank() == 0):
            one_step_bs = (args.train_micro_batch_size_per_gpu *
                           args.gradient_accumulation_steps *
                           dist.get_world_size() * rounds)
            print(' At step {}, the throughput is {:.2f} Samples/s'.format(
                global_step * args.gradient_accumulation_steps,
                one_step_bs / all_step_time),
                  flush=True)
            all_step_time = 0.0

    pretrain_dataset_provider.release_shard(index)

    global_data_samples = current_data_sample_count

    # Run validation loss
    if not finetune and args.max_seq_length == 512:
        pretrain_validation(args, index, model)
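# --- Worked example of the throughput report above (numbers are made up) ---
# Each reporting window covers `rounds` optimizer steps, and each step consumes
# micro_batch * gradient_accumulation_steps * world_size samples.
micro_batch = 16       # assumed --train_micro_batch_size_per_gpu
grad_accum_steps = 4   # assumed --gradient_accumulation_steps
world_size = 8         # assumed number of ranks
rounds = 20            # reporting interval used in train()
window_samples = micro_batch * grad_accum_steps * world_size * rounds  # 10240
window_seconds = 25.0  # assumed measured wall time for the window
print(f'throughput = {window_samples / window_seconds:.2f} samples/s')  # 409.60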
def print_object(obj, name, exclude_list=[]):
    logger.info('{}:'.format(name))
    for arg in sorted(vars(obj)):
        if arg not in exclude_list:
            dots = '.' * (29 - len(arg))
            logger.info('  {} {} {}'.format(arg, dots, getattr(obj, arg)))
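# --- Example usage (illustrative) ---
# Dump parsed arguments while hiding non-printable members; the attribute names
# below are assumptions for the example. Each attribute is logged on its own
# line as '  name ..... value', padded with dots for alignment.
import argparse

example_args = argparse.Namespace(lr=1e-4, batch_size=32, logger=None)
print_object(example_args, 'Training arguments', exclude_list=['logger'])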