class FP16_UnfusedOptimizer(object): """ FP16 Optimizer without weight fusion to support LAMB optimizer For usage example please see, TODO: DeepSpeed V2 Tutorial """ def __init__(self, init_optimizer, static_loss_scale=1.0, dynamic_loss_scale=False, dynamic_loss_args=None, verbose=True, mpu=None, clip_grad=0.0, fused_lamb_legacy=False): self.fused_lamb_legacy = fused_lamb_legacy if torch.distributed.get_rank() == 0: logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ') if not torch.cuda.is_available: raise SystemError("Cannot use fp16 without CUDA.") self.optimizer = init_optimizer # param groups self.fp16_groups = [] self.fp32_groups = [] # loop to deal with groups for i, param_group in enumerate(self.optimizer.param_groups): #fp16 weights that represents the actual model weights self.fp16_groups.append(param_group['params']) #creating a fp32 copy of the weights that will be updated first then #copied to fp16 weights fp32_group = [p.clone().float().detach() for p in param_group['params']] #incase the internal optimizer needs it for p in fp32_group: p.requires_grad = True #setting the param groups in the optimizer to point to fp32 #note these are not the weights used by the model #the model uses the fp16 version that we added to fp16_group self.fp32_groups.append(fp32_group) param_group['params'] = self.fp32_groups[i] # we may have a way of fusing dynamic scale. Do not support for now if dynamic_loss_scale: self.dynamic_loss_scale = True self.cur_iter = 0 self.last_overflow_iter = -1 self.scale_factor = 2.0 if dynamic_loss_args is None: self.cur_scale = 1.0 * 2**16 self.scale_window = 1000 self.min_loss_scale = 0.25 else: self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE] self.scale_window = dynamic_loss_args[SCALE_WINDOW] self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE] else: self.dynamic_loss_scale = False self.cur_iter = 0 self.cur_scale = static_loss_scale self.verbose = verbose self.clip_grad = clip_grad self.norm_type = 2 TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) if TORCH_MAJOR == 0 and TORCH_MINOR <= 4: self.clip_grad_norm = torch.nn.utils.clip_grad_norm else: self.clip_grad_norm = torch.nn.utils.clip_grad_norm_ self.mpu = mpu self.overflow = False self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu) self.initialize_optimizer_states() def zero_grad(self, set_grads_to_None=True): """ Zero FP16 parameter grads. """ # FP32 grad should never exist outside of the step function # For speed, set model fp16 grad to None by default for group in self.fp16_groups: for p in group: if set_grads_to_None: p.grad = None else: if p.grad is not None: p.grad.detach_() p.grad.zero_() def step_fused_lamb(self, closure=None): """ Not supporting closure. """ # First compute norm for all group so we know if there is overflow grads_groups_flat = [] grads_groups = [] norm_groups = [] for i, group in enumerate(self.fp16_groups): grads = [ torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group ] grads_groups.append(grads) grads_groups_flat.append(_flatten_dense_tensors(grads)) norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu)) self.overflow = self.overflow_checker.check_using_norm(norm_groups) prev_scale = self.cur_scale self._update_scale(self.overflow) if self.overflow: if self.verbose: logger.info( "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. 
Attempted loss " "scale: {}, reducing to {}".format(prev_scale, self.cur_scale)) return self.overflow combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False) self.optimizer.step(grads=grads_groups, output_params=self.fp16_groups, scale=combined_scale) return self.overflow def step(self, closure=None): """ Not supporting closure. """ if self.fused_lamb_legacy: return self.step_fused_lamb() self.overflow = self.overflow_checker.check() prev_scale = self.cur_scale self._update_scale(self.overflow) if self.overflow: if self.verbose: logger.info( "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss " "scale: {}, reducing to {}".format(prev_scale, self.cur_scale)) return self.overflow norm_groups = [] for i, group in enumerate(self.fp16_groups): norm_groups.append(get_grad_norm(group, mpu=self.mpu)) # copying gradients to fp32 to work with fp32 parameters for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]): if fp16_param.grad is None: fp32_param.grad = torch.zeros(fp16_param.size(), dtype=fp32_param.dtype, device=fp32_param.device) else: fp32_param.grad = fp16_param.grad.to(fp32_param.dtype) self.unscale_and_clip_grads(norm_groups) self.optimizer.step() for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups): for fp32_param, fp16_param in zip(fp32_group, fp16_group): #remove the fp32 grad fp32_param.grad = None #copy data from fp32 to fp16 fp16_param.data.copy_(fp32_param.data) return self.overflow def unscale_and_clip_grads(self, norm_groups, apply_scale=True): total_norm = 0.0 for norm in norm_groups: total_norm += norm**2.0 total_norm = math.sqrt(total_norm) # compute combined scale factor for this group combined_scale = self.cur_scale if self.clip_grad > 0.: # norm is in fact norm*scale clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad if clip > 1: combined_scale = clip * self.cur_scale if apply_scale: for group in self.fp32_groups: for param in group: if param.grad is not None: param.grad.data.mul_(1. / combined_scale) return combined_scale def backward(self, loss): """ :attr:`backward` performs the following steps: 1. fp32_loss = loss.float() 2. scaled_loss = fp32_loss*loss_scale 3. 
scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves """ scaled_loss = (loss.float()) * self.cur_scale scaled_loss.backward() def _update_scale(self, skip): if self.dynamic_loss_scale: prev_scale = self.cur_scale if skip: self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale) self.last_overflow_iter = self.cur_iter if self.verbose: logger.info("Grad overflow on iteration: %s", self.cur_iter) logger.info( f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}" ) else: # Ensure self.scale_window updates since last overflow stable_interval = (self.cur_iter - self.last_overflow_iter) - 1 if (stable_interval > 0) and (stable_interval % self.scale_window == 0): self.cur_scale *= self.scale_factor if self.verbose: logger.info( f"No Grad overflow for {self.scale_window} iterations") logger.info( f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}" ) else: if skip: logger.info("Grad overflow on iteration %s", self.cur_iter) logger.info("Using static loss scale of %s", self.cur_scale) self.cur_iter += 1 return # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" def _get_state(self): return self.optimizer.state def _set_state(self, value): self.optimizer.state = value state = property(_get_state, _set_state) # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" # (for example, to adjust the learning rate) def _get_param_groups(self): return self.optimizer.param_groups def _set_param_groups(self, value): self.optimizer.param_groups = value param_groups = property(_get_param_groups, _set_param_groups) def state_dict(self): """ Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict of the contained Pytorch optimizer. Example:: checkpoint = {} checkpoint['model'] = model.state_dict() checkpoint['optimizer'] = optimizer.state_dict() torch.save(checkpoint, "saved.pth") """ state_dict = {} state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['cur_scale'] = self.cur_scale state_dict['cur_iter'] = self.cur_iter if state_dict['dynamic_loss_scale']: state_dict['last_overflow_iter'] = self.last_overflow_iter state_dict['scale_factor'] = self.scale_factor state_dict['scale_window'] = self.scale_window state_dict['optimizer_state_dict'] = self.optimizer.state_dict() state_dict['fp32_groups'] = self.fp32_groups return state_dict def load_state_dict(self, state_dict, load_optimizer_states=True): """ Loads a state_dict created by an earlier call to state_dict(). If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, whose parameters in turn came from ``model``, it is expected that the user will call ``model.load_state_dict()`` before ``fp16_optimizer_instance.load_state_dict()`` is called. Example:: model = torch.nn.Linear(D_in, D_out).cuda().half() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) ... checkpoint = torch.load("saved.pth") model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) """ # I think it should actually be ok to reload the optimizer before the model. 
self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] self.cur_scale = state_dict['cur_scale'] self.cur_iter = state_dict['cur_iter'] if state_dict['dynamic_loss_scale']: self.last_overflow_iter = state_dict['last_overflow_iter'] self.scale_factor = state_dict['scale_factor'] self.scale_window = state_dict['scale_window'] if load_optimizer_states: self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) # At this point, the optimizer's references to the model's fp32 parameters are up to date. # The optimizer's hyperparameters and internal buffers are also up to date. # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still # out of date. There are two options. # 1: Refresh the master params from the model's fp16 params. # This requires less storage but incurs precision loss. # 2: Save and restore the fp32 master copies separately. # We choose option 2. # # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device # of their associated parameters, because it's possible those buffers might not exist yet in # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been # constructed in the same way as the one whose state_dict we are loading, the same master params # are guaranteed to exist, so we can just copy_() from the saved master params. for current_group, saved_group in zip(self.fp32_groups, state_dict['fp32_groups']): for current, saved in zip(current_group, saved_group): current.data.copy_(saved.data) def __repr__(self): return repr(self.optimizer) def initialize_optimizer_states(self): for i, group in enumerate(self.fp16_groups): for param in group: param.grad = torch.zeros(param.size(), dtype=param.dtype, device=torch.cuda.current_device()) for i, group in enumerate(self.fp32_groups): for param in group: param.grad = torch.zeros(param.size(), dtype=param.dtype, device=torch.cuda.current_device()) self.optimizer.step() for i, group in enumerate(self.fp16_groups): for param in group: param.grad = None for i, group in enumerate(self.fp32_groups): for param in group: param.grad = None
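# Illustrative usage sketch for FP16_UnfusedOptimizer (a hedged example, not part of this
# module's API surface; `FusedLamb`, `model`, and `data_loader` are assumed to be defined by
# the caller, and torch.distributed is assumed to be initialized):
#
#   base_optimizer = FusedLamb(model.parameters(), lr=1e-3)   # hypothetical LAMB optimizer
#   optimizer = FP16_UnfusedOptimizer(base_optimizer,
#                                     dynamic_loss_scale=True,
#                                     clip_grad=1.0)
#   for batch in data_loader:
#       loss = model(batch)
#       optimizer.backward(loss)   # scales the fp32 loss by cur_scale before backward()
#       optimizer.step()           # unscales/clips grads, steps on fp32, copies back to fp16;
#                                  # on overflow the step is skipped and the scale is reduced
#       optimizer.zero_grad()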
class FP16_DeepSpeedZeroOptimizer_Stage1(object): """ FP16_DeepSpeedZeroOptimizer_Stage1 designed to reduce the memory footprint required for training large deep learning models. For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models https://arxiv.org/abs/1910.02054 This version aligns with stage-1 in the paper above. """ def __init__(self, init_optimizer, static_loss_scale=1.0, dynamic_loss_scale=False, dynamic_loss_args=None, verbose=True, dp_process_group=None, partition_size=None, mpu=None, all_gather_partitions=True, allgather_size=500000000, clip_grad=0.0, max_elements_per_comm=5e8): if dp_process_group is not None and partition_size is not None: raise ValueError("Cannot specify both dp_process_group " "and partition size") if dp_process_group is None: dp_process_group = _initialize_parameter_parallel_groups(partition_size) if not torch.cuda.is_available: raise SystemError("Cannot use fp16 without CUDA.") self.optimizer = init_optimizer self.verbose = verbose self.dp_process_group = dp_process_group # TODO: automatically turn off if #params > some_limit self.all_gather_partitions = all_gather_partitions self.allgather_size = allgather_size self.max_elements_per_comm = max_elements_per_comm logger.info("max_elements_per_comm={}".format(max_elements_per_comm)) # param flattened by groups self.fp16_groups = [] self.fp16_groups_flat = [] # Setup bookkeeping data structures depending on partitioning type # parallel_sub_partitioned_fp16_groups[group-idx] -> [comm-ids] -> [rank-ids] self.parallel_sub_partitioned_fp16_groups = [] # same underlying data as above but viewed as: [groups] -> [rank-ids] -> [comm-ids] self.parallel_comm_sub_partitioned_fp16_groups = [] # 32-bit sub-partitions of the parallel partitioned parameters # that this process will update self.local_sub_partitions_of_fp32_groups = [] # param partition info # parameters in each group that will not be updated by this process directly self.params_not_local = [] # parameters that will be updated by this process directly self.params_in_rank_sub_partitions = [] # parameter offsets for parameters in sub-partitions. Parameter # boundaries may not align with sub-partition boundaries # so we need to keep track of the offsets self.params_in_rank_sub_partitions_offsets = [] # number of elements per sub-partition in each group self.sub_partition_sizes = [] # number of communication intervals for each group self.num_comm_intervals_per_group = [] local_rank = dist.get_rank(group=self.dp_process_group) self.group_paddings = [] self.partition_count = dist.get_world_size(group=self.dp_process_group) # loop to deal with groups for i, param_group in enumerate(self.optimizer.param_groups): # push this group to list before modify self.fp16_groups.append(param_group['params']) # flattens all tensors into single 1d tensor aligned with sub-partition size for later dividing # RS: create aligned sub-partitions self.fp16_groups_flat.append( flatten_dense_tensors_sub_partition_aligned( tensor_list=self.fp16_groups[i], dp=dist.get_world_size(group=self.dp_process_group), max_elements_per_comm=self.max_elements_per_comm, pg=self.dp_process_group)) # TODO: I don't think this does anything? 
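            # Note: _unflatten_dense_tensors returns views into the flat buffer, so the
            # `p.data = q.data` re-assignment below re-binds each fp16 model parameter to
            # share storage with fp16_groups_flat[i] rather than being a no-op.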
# set model fp16 weight to slices of flattened buffer updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i]) for p, q in zip(self.fp16_groups[i], updated_params): p.data = q.data # divide the flat weights into near equal partition equal to the data parallel degree # each process will compute on a different part of the partition # RS: split into two layer list -> [comm-id] -> [sub-partitions per rank] comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \ self.get_data_parallel_sub_partitions( tensor=self.fp16_groups_flat[i], max_elements_per_comm=self.max_elements_per_comm, world_size=dist.get_world_size( group=self.dp_process_group), dp_process_group=self.dp_process_group ) self.parallel_comm_sub_partitioned_fp16_groups.append( comm_partitions) # comm -> rank self.parallel_sub_partitioned_fp16_groups.append( dp_sub_partitions) # rank -> comm self.sub_partition_sizes.append(sub_partition_size) self.num_comm_intervals_per_group.append(num_comm_intervals) # data_parallel_partitions = self.get_data_parallel_partitions(self.fp16_groups_flat[i]) # self.parallel_partitioned_fp16_groups.append(data_parallel_partitions) # a partition of the fp32 master weights that will be updated by this process # RS: store/detach/cast our local sub-partitions local_sub_partitions = [] for sub_partition in self.parallel_sub_partitioned_fp16_groups[i][ local_rank]: fp32_sub_partition = sub_partition.clone().float().detach() fp32_sub_partition.requires_grad = True local_sub_partitions.append(fp32_sub_partition) self.local_sub_partitions_of_fp32_groups.append(local_sub_partitions) # Compute sub_partition paddings sub_partition_paddings = get_group_alignment_padding( tensor_list=self.fp16_groups[i], sub_partition_size=sub_partition_size, sub_partition_count=num_comm_intervals * self.partition_count) self.group_paddings.append(sub_partition_paddings) # modify optimizer of have flat master weight # self.single_partition_of_fp32_groups[i].requires_grad = True # keep this in case internal optimizer uses it param_group['params'] = self.local_sub_partitions_of_fp32_groups[i] # RS: divide up the sub-partitions and keep track of offsets for each param # partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group) params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, \ params_not_local = self.get_all_sub_partition_info( tensor_list=self.fp16_groups[i], all_element_intervals=element_intervals, local_rank=local_rank, world_size=dist.get_world_size(group=self.dp_process_group) ) self.params_in_rank_sub_partitions.append(params_in_rank_sub_partition) self.params_not_local.append(params_not_local) self.params_in_rank_sub_partitions_offsets.append( params_in_rank_sub_partitions_offsets) # we may have a way of fusing dynamic scale. 
        # Do not support for now.
        if dynamic_loss_scale:
            if dynamic_loss_args is None:
                self.loss_scaler = DynamicLossScaler()
            else:
                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)

            self.dynamic_loss_scale = True
        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(scale=static_loss_scale)
            self.cur_iter = 0

        self.mpu = mpu
        self.clip_grad = clip_grad
        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups,
                                              mpu=self.mpu,
                                              zero_reduce_scatter=True)

        self._initialize_optimizer_states()

    def _initialize_optimizer_states(self):
        for group_idx, group in enumerate(self.local_sub_partitions_of_fp32_groups):
            for idx, sub_partition_param in enumerate(group):
                sub_partition_grad = torch.zeros(int(self.sub_partition_sizes[group_idx]),
                                                 dtype=sub_partition_param.dtype).cuda()
                sub_partition_param.grad = sub_partition_grad

        self.optimizer.step()

        for group in self.local_sub_partitions_of_fp32_groups:
            for idx, sub_partition_param in enumerate(group):
                sub_partition_param.grad = None

    @staticmethod
    def get_data_parallel_sub_partitions(tensor,
                                         max_elements_per_comm,
                                         world_size,
                                         dp_process_group=None):
        total_num_elements = tensor.numel()

        # if total elements is less than our max, revert to splitting into dp partitions
        max_elements_per_comm = min(total_num_elements, max_elements_per_comm)
        sub_partition_size = int(max_elements_per_comm // world_size)

        # Ensure partition alignment was done correctly
        num_sub_partitions = int(total_num_elements // sub_partition_size)
        assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(
            total_num_elements, sub_partition_size)

        # Ensure comm interval alignment was done correctly.
        num_comm_intervals = int(num_sub_partitions // world_size)
        assert num_sub_partitions % world_size == 0, "{} % {} != 0".format(
            num_sub_partitions, world_size)

        if not dist.is_initialized() or dist.get_rank(group=dp_process_group) == 0:
            logger.info("**** partition info:")
            logger.info("\t total_num_elements=%s", total_num_elements)
            logger.info("\t world_size=%s", world_size)
            logger.info("\t max_elements_per_comm=%s", max_elements_per_comm)
            logger.info("\t sub_partition_size=%s", sub_partition_size)
            logger.info("\t num_sub_partitions=%s", num_sub_partitions)
            logger.info("\t num_comm_intervals=%s", num_comm_intervals)
            logger.info("****")

        # [comm_id] -> [rank]
        comm_partitions = []
        for _ in range(num_comm_intervals):
            comm_partitions.append([])

        start = 0
        comm_id = 0
        element_intervals = defaultdict(list)  # [rank] -> [(start,end), (start,end), ...]
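        # Worked example of the assignment below (illustrative numbers): with world_size=2
        # and num_sub_partitions=4, idx 0 -> rank 0 and idx 1 -> rank 1 form comm interval 0,
        # then idx 2 -> rank 0 and idx 3 -> rank 1 form comm interval 1, so each rank owns
        # exactly one sub-partition per communication interval.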
for idx in range(num_sub_partitions): rank_id = idx % world_size sub_partition = tensor.narrow(0, start, sub_partition_size).detach() element_intervals[rank_id].append((start, start + sub_partition_size)) comm_partitions[comm_id].append(sub_partition) start = start + sub_partition_size if rank_id == (world_size - 1): comm_id += 1 # [rank] -> [comm_id] sub_partitions = [] for _ in range(world_size): sub_partitions.append([]) for comm_id, partitions in enumerate(comm_partitions): for rank_id, partition in enumerate(partitions): sub_partitions[rank_id].append(partition) return comm_partitions, sub_partitions, element_intervals, sub_partition_size, num_comm_intervals @staticmethod def get_all_sub_partition_info(tensor_list, all_element_intervals, local_rank, world_size): params_not_local = [] # [rank] -> [comm-id] -> [param/offset] params_in_rank_sub_partition = [] params_in_rank_sub_partitions_offsets = [] for rank in range(world_size): params_in_local_sub_partition = [] local_sub_partition_offsets = [] comm_tensor_list = [] comm_offset_list = [] current_index = 0 prev_comm_idx = 0 for iii, tensor in enumerate(tensor_list): tensor_size = tensor.numel() #if local_rank == 0: # # logger.info("rank={}, current_index={}, tensor_size={}, tensor-idx={}".format(rank, # current_index, tensor_size, iii)) results_list = _range_check(current_index, all_element_intervals[rank], tensor_size) for contained, offset, comm_idx in results_list: #if local_rank == 0: # logger.info("rank={}, contained={}, offset={}, comm_idx={}".format(rank, contained, # offset, comm_idx)) if contained: if prev_comm_idx != comm_idx: params_in_local_sub_partition.append(comm_tensor_list) comm_tensor_list = [] local_sub_partition_offsets.append(comm_offset_list) comm_offset_list = [] comm_tensor_list.append(tensor) comm_offset_list.append(offset) prev_comm_idx = comm_idx elif rank == local_rank: params_not_local.append(tensor) current_index = current_index + tensor_size #assert len(comm_tensor_list) > 0 #assert len(comm_offset_list) > 0 params_in_local_sub_partition.append(comm_tensor_list) local_sub_partition_offsets.append(comm_offset_list) params_in_rank_sub_partition.append(params_in_local_sub_partition) params_in_rank_sub_partitions_offsets.append(local_sub_partition_offsets) return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local @staticmethod def get_flat_sub_partitions(comm_tensor_list, comm_param_offsets, sub_partition_size, dtype, num_comm_intervals=None, default_device=None, return_partition_params=False): partition_params = [] final_param_offsets = [] flat_sub_partitions = [] for tensor_list, param_offsets in zip(comm_tensor_list, comm_param_offsets): flat_tensor_list = [] current_size = 0 my_offsets = [] my_params = [] if dtype is None: dtype = tensor_list[0].dtype for i, tensor in enumerate(tensor_list): if tensor.grad is None: tensor.grad = torch.zeros(tensor.size(), dtype=tensor.dtype, device=tensor.device) param = tensor tensor = tensor.grad num_elements = tensor.numel() tensor_offset = 0 #we need to offset to get to the right element if i == 0 and param_offsets[i] > 0: tensor_offset = param_offsets[i] num_elements = num_elements - tensor_offset # We don't need all elements of the tensor if this tensor is # larger than we have space for in our curr sub-partition if num_elements > (sub_partition_size - current_size): num_elements = sub_partition_size - current_size #we need a narrow view of the tensor based on the tensor offset and number of elements that #we need from this tensor 
                if tensor_offset > 0 or num_elements < tensor.numel():
                    flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
                        0,
                        int(tensor_offset),
                        int(num_elements)).to(dtype))
                else:
                    flat_tensor_list.append(tensor.to(dtype))
                my_params.append(param)

                # remember offset into the partition and number of elements for this tensor
                my_offsets.append((current_size, num_elements))

                current_size = current_size + num_elements

            # this means it's the last partition and does not align with the dp boundary.
            # We need to pad before flattening
            if current_size < sub_partition_size:
                my_offsets.append((None, None))
                my_params.append(None)
                if len(tensor_list) == 0:
                    assert default_device is not None
                    flat_tensor_list.append(
                        torch.zeros(int(sub_partition_size - current_size),
                                    dtype=dtype,
                                    device=default_device))
                else:
                    flat_tensor_list.append(
                        torch.zeros(int(sub_partition_size - current_size),
                                    dtype=dtype,
                                    device=tensor_list[0].device))
            partition_params.append(my_params)  # flat_tensor_list)
            final_param_offsets.append(my_offsets)
            assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(len(flat_tensor_list), len(my_offsets))
            flat_sub_partitions.append(_flatten_dense_tensors(flat_tensor_list))

        if num_comm_intervals is not None and len(flat_sub_partitions) < num_comm_intervals:
            # logger.info("padding w. sub partitions to ensure uniform communication")
            device = flat_sub_partitions[0].device
            for _ in range(num_comm_intervals - len(flat_sub_partitions)):
                flat_sub_partitions.append(
                    torch.zeros(int(sub_partition_size),
                                dtype=dtype,
                                device=device))
                partition_params.append([None])
                final_param_offsets.append([(None, None)])

        if return_partition_params:
            assert len(flat_sub_partitions) == len(partition_params)
            assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params), len(final_param_offsets))
            return flat_sub_partitions, partition_params, final_param_offsets
        return flat_sub_partitions

    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grad should never exist.
# For speed, set model fp16 grad to None by default for group in self.fp16_groups: for p in group: if set_grads_to_None: p.grad = None else: if p.grad is not None: p.grad.detach_() p.grad.zero_() def free_grad_in_param_list(self, param_list): for p in param_list: if isinstance(p, list): for _p in p: _p.grad = None else: p.grad = None def reduce_scatter_gradients(self, postscale_gradients, gradient_predivide_factor, gradient_average): world_size = dist.get_world_size(group=self.dp_process_group) local_rank = dist.get_rank(group=self.dp_process_group) for i, group in enumerate(self.fp16_groups): partition_param_map = {} param_partition_map = {} my_params = set() # [rank] -> [comm] -> partition num_comm_intervals = self.num_comm_intervals_per_group[i] all_sub_partitions = [] for rank in range(world_size): # gsp is list of partitions indexed by comm_idx #FIXME: currently hardcoding fp16, should infer dtype grad_sub_partitions, partition_params, param_offsets = self.get_flat_sub_partitions( comm_tensor_list=self.params_in_rank_sub_partitions[i][rank], comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank], sub_partition_size=self.sub_partition_sizes[i], dtype=torch.half, #self.params_in_rank_sub_partitions[i][rank][0][0].dtype, num_comm_intervals=self.num_comm_intervals_per_group[i], default_device='cuda', #self.params_in_rank_sub_partitions[i][rank][0][0].device, return_partition_params=True) all_sub_partitions.append(grad_sub_partitions) # create map from partition -> params in that partition for comm_idx, part in enumerate(grad_sub_partitions): partition_param_map[part] = (partition_params[comm_idx], param_offsets[comm_idx]) for comm_idx, params in enumerate(partition_params): for pidx, p in enumerate(params): # store the parameters we care about locally if rank == local_rank: my_params.add(p) # map from param -> partitions if p in param_partition_map: param_partition_map[p].append(grad_sub_partitions[comm_idx]) else: param_partition_map[p] = [grad_sub_partitions[comm_idx]] assert len(grad_sub_partitions) == num_comm_intervals if not postscale_gradients: raise NotImplementedError("pre-scale_gradients is not implemented") all_comm_partitions = [] for comm_idx in range(num_comm_intervals): single_comm_all_partitions = [] for rank in range(world_size): single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx]) dist.reduce_scatter(output=single_comm_all_partitions[local_rank], input_list=single_comm_all_partitions, group=self.dp_process_group) if gradient_average: for partition in single_comm_all_partitions: partition.mul_(gradient_predivide_factor / world_size) all_comm_partitions.append(single_comm_all_partitions) for p in my_params: partitions = param_partition_map[p] parts = [] for part in partitions: params, offsets = partition_param_map[part] found = False for p_idx, _p in enumerate(params): if p.__hash__() == _p.__hash__(): found = True if offsets[p_idx][0] is not None: my_part = part.narrow(0, offsets[p_idx][0], offsets[p_idx][1]) parts.append(my_part) assert found if p is not None: updated_grad = _unflatten_dense_tensors(torch.cat(parts), [p]) p.grad.copy_(updated_grad[0]) def step(self, closure=None): # First compute norm for all group so we know if there is overflow self.overflow = self.overflow_checker.check() prev_scale = self.loss_scale self._update_scale(self.overflow) if self.overflow: self.zero_grad() if self.verbose: logger.info("[deepspeed] OVERFLOW! Skipping step. 
Attempted loss " "scale: {}, reducing to {}".format( prev_scale, self.loss_scale)) return self.overflow norm_groups = [] local_sub_partitions_grad_groups = [] partition_id = dist.get_rank(group=self.dp_process_group) for i, group in enumerate(self.fp16_groups): #TODO RS: update get grad norm to support sub partitions norm_groups.append(get_grad_norm(group, mpu=self.mpu)) #RS: update free grads w.r.t. sub partitions #free gradients for all the parameters that are not updated by this process self.free_grad_in_param_list(self.params_not_local[i]) #create flat gradients for parameters updated by this process #tensor_list, first_offset, partition_size, dtype #single_grad_partition = self.get_flat_partition( # tensor_list=self.params_in_partition[i], # first_offset=self.first_offset[i], # partition_size=self.partition_size[i], # dtype=self.single_partition_of_fp32_groups[i].dtype #) #TODO RS: can we safely use dtype of the first sub-partition? i think so local_grad_sub_partitions = self.get_flat_sub_partitions( comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id], comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i] [partition_id], sub_partition_size=self.sub_partition_sizes[i], dtype=self.local_sub_partitions_of_fp32_groups[i][0].dtype, num_comm_intervals=self.num_comm_intervals_per_group[i], default_device=self.local_sub_partitions_of_fp32_groups[i][0].device) #RS: update all our local params with sub-partition grads #logger. info("self.local_sub_partitions_of_fp32_groups[i]={}, local_grad_sub_partitions={}".format(len(self.local_sub_partitions_of_fp32_groups[i]), len(local_grad_sub_partitions))) for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_fp32_groups[i]): sub_partition_param.grad = local_grad_sub_partitions[idx] #self.single_partition_of_fp32_groups[i].grad = single_grad_partition #RS: update free grads for sub-partitions #release all the gradient since we have already created a necessary copy in dp_grad_partition self.free_grad_in_param_list( self.params_in_rank_sub_partitions[i][partition_id]) local_sub_partitions_grad_groups.append(local_grad_sub_partitions) #RS: update unscale/clip with sub partitions self.unscale_and_clip_grads(local_sub_partitions_grad_groups, norm_groups) self.optimizer.step() #RS: clear our sub partition grads #get rid of the fp32 gradients. Not needed anymore for group in self.local_sub_partitions_of_fp32_groups: for idx, sub_partition_param in enumerate(group): sub_partition_param.grad = None #group.grad = None #NOTE RS: removed norm_groups outer loop from original code, i don't think it's needed #RS: copy all sub-partition fp32 data to fp16 sub partitions # copy fp32 param data to fp16 partitions w.r.t. our local rank for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups): for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions): local_sub_partition_param_fp16.data.copy_( local_sub_partition_param_fp32.data) #RS: all_gather/broadcast sub-partitions in separate comm calls #gather the updated weights from everyone for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups: for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions): dist.all_gather(sub_partitions, sub_partitions[partition_id], group=self.dp_process_group) # TODO: we probably don't need this? 
        # Just to be safe.
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        return self.overflow

    def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
        total_norm = 0.0
        for norm in norm_groups:
            total_norm += norm**2.0
        total_norm = math.sqrt(total_norm)

        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.loss_scale

        for grad in grad_groups_flat:
            if isinstance(grad, list):
                sub_partitions = grad
                for g in sub_partitions:
                    g.data.mul_(1. / combined_scale)
            else:
                grad.data.mul_(1. / combined_scale)

    def backward(self, loss, retain_graph=False):
        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

    def _update_scale(self, has_overflow=False):
        self.loss_scaler.update_scale(has_overflow)

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale

    def _set_loss_scale(self, value):
        self.loss_scaler.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)
    cur_scale = property(_get_loss_scale, _set_loss_scale)

    # Return group tensors after removing the paddings that are added for alignment to the DP world size.
    # This method works on the assumption that each group contains sub-partitions.
    def _get_groups_without_padding(self, groups_with_padding):
        groups_without_padding = []
        local_rank = dist.get_rank(group=self.dp_process_group)
        for i, group in enumerate(groups_with_padding):
            low_index = local_rank * len(group)
            high_index = (local_rank + 1) * len(group)
            group_paddings = self.group_paddings[i][low_index:high_index]
            lean_sub_partitions = []
            for j, sub_partition in enumerate(group):
                lean_length = sub_partition.numel() - group_paddings[j]
                lean_sub_partitions.append(sub_partition[:lean_length])
            groups_without_padding.append(lean_sub_partitions)

        return groups_without_padding

    # Return optimizer state after removing the paddings that are added for alignment.
    def _get_state_without_padding(self, state_with_padding, padding):
        lean_state = {}
        for key, value in state_with_padding.items():
            lean_length = value.numel() - padding
            lean_state[key] = value[:lean_length]

        return lean_state

    # Return base optimizer states.
    # This method assumes the params in each group are this rank's local sub-partitions.
def _get_base_optimizer_state(self): optimizer_groups_state = [] local_rank = dist.get_rank(group=self.dp_process_group) for group_idx, group in enumerate(self.optimizer.param_groups): group_lean_state = [] low_index = local_rank * self.num_comm_intervals_per_group[group_idx] high_index = (local_rank + 1) * self.num_comm_intervals_per_group[group_idx] param_paddings = self.group_paddings[group_idx][low_index:high_index] for param_idx, param in enumerate(group['params']): lean_state = self._get_state_without_padding(self.optimizer.state[param], param_paddings[param_idx]) group_lean_state.append(lean_state) optimizer_groups_state.append(group_lean_state) return optimizer_groups_state def state_dict(self): """ Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict of the contained Pytorch optimizer. Example:: checkpoint = {} checkpoint['model'] = model.state_dict() checkpoint['optimizer'] = optimizer.state_dict() torch.save(checkpoint, "saved.pth") """ state_dict = {} state_dict['loss_scaler'] = self.loss_scaler state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow state_dict['base_optimizer_state'] = self._get_base_optimizer_state() state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES state_dict['partition_count'] = self.partition_count state_dict['num_comm_intervals_per_group'] = self.num_comm_intervals_per_group # Remove paddings for DP alignment to enable loading for other alignment values fp32_groups_without_padding = self._get_groups_without_padding( self.local_sub_partitions_of_fp32_groups) state_dict['local_sub_partitions_of_fp32_groups'] = fp32_groups_without_padding return state_dict def _retrieve_group_sub_partition_weights(self, all_partition_fp32_weights): partition_id = dist.get_rank(group=self.dp_process_group) all_sub_partition_weights = [] for partition_weights in all_partition_fp32_weights: for sub_partition_weights in partition_weights: all_sub_partition_weights.append(sub_partition_weights) flat_merged_weights = flatten_dense_tensors_sub_partition_aligned( tensor_list=all_sub_partition_weights, dp=dist.get_world_size(group=self.dp_process_group), max_elements_per_comm=self.max_elements_per_comm, pg=self.dp_process_group) comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \ self.get_data_parallel_sub_partitions( tensor=flat_merged_weights, max_elements_per_comm=self.max_elements_per_comm, world_size=dist.get_world_size(group=self.dp_process_group), dp_process_group=self.dp_process_group ) return [sub_partition for sub_partition in dp_sub_partitions[partition_id]] # Restore base optimizer fp32 weights from checkpoint by: # 1) Merging fp32 weights from checkpoints of all partitions # 2) Extracting fp32 weights for current partition from merged weights # 3) Using extracted weights to update base optimizer weights directly. 
def _restore_from_fp32_weights(self, all_state_dict): sub_partition_of_fp32_groups = [] for group_idx in range(len(self.local_sub_partitions_of_fp32_groups)): all_partition_fp32_weights = [ sd['local_sub_partitions_of_fp32_groups'][group_idx] for sd in all_state_dict ] sub_partition_weights = self._retrieve_group_sub_partition_weights( all_partition_fp32_weights) sub_partition_of_fp32_groups.append(sub_partition_weights) for current_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, sub_partition_of_fp32_groups): for current_sub_part, saved_sub_part in zip(current_group, saved_group): current_sub_part.data.copy_(saved_sub_part.data) # Extract optimizer state for current partition from merged states of all partitions def _partition_base_optimizer_state(self, state_key, all_partition_states): partition_id = dist.get_rank(group=self.dp_process_group) alignment = dist.get_world_size(group=self.dp_process_group) flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned( tensor_list=all_partition_states, dp=dist.get_world_size(group=self.dp_process_group), max_elements_per_comm=self.max_elements_per_comm, pg=self.dp_process_group) comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \ self.get_data_parallel_sub_partitions( tensor=flat_merged_partitions, max_elements_per_comm=self.max_elements_per_comm, world_size=dist.get_world_size(group=self.dp_process_group), dp_process_group=self.dp_process_group ) return [sub_partition for sub_partition in dp_sub_partitions[partition_id]] # Compute the optimizer state partitions for the group by # 1) Merging state values across the previous partitioning. # 2) Repartition state values for the new partitioning # 3) Return state corresponding to local partition def _retrieve_group_optimizer_states(self, all_partition_states): merged_optimizer_states = {} for partition_state in all_partition_states: for sub_partition_state in partition_state: for key, value in sub_partition_state.items(): if not key in merged_optimizer_states.keys(): merged_optimizer_states[key] = [value] else: merged_optimizer_states[key].append(value) group_optimizer_states = {} for key, value in merged_optimizer_states.items(): group_optimizer_states[key] = self._partition_base_optimizer_state( key, value) return group_optimizer_states # Restore base optimizer state from checkpoint by # 1) Merging optimizer state from checkpoints of all partitions # 2) Extracting optimizer state for current partition from the merged state # 3) Using the extracted value to directly update the base optimizer. 
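    # For example, with an Adam-style base optimizer the 'exp_avg' and 'exp_avg_sq' tensors
    # saved by every rank are concatenated per key, re-flattened with the current alignment,
    # and re-sliced so that this rank keeps only the slices matching its own sub-partitions.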
def _restore_base_optimizer_state(self, state_dict_list): base_optimizer_group_states = [] for group_idx in range(len(self.optimizer.param_groups)): all_partition_group_states = [ sd['base_optimizer_state'][group_idx] for sd in state_dict_list ] group_optimizer_states = self._retrieve_group_optimizer_states( all_partition_group_states) base_optimizer_group_states.append(group_optimizer_states) for group_idx, group in enumerate(self.optimizer.param_groups): for param_idx, param in enumerate(group['params']): for key, saved in base_optimizer_group_states[group_idx].items(): current = self.optimizer.state[param][key] current.data.copy_(saved[param_idx].data) # Restore base optimizer fp32 weights from ZeRO fp16 weights def _restore_from_fp16_weights(self): partition_id = dist.get_rank(group=self.dp_process_group) for fp16_partitions, fp32_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups): for fp16_sub_partition, fp32_sub_partition in zip(fp16_partitions[partition_id], fp32_partitions): fp32_sub_partition.data.copy_(fp16_sub_partition.data) # Refresh the fp32 master params from the fp16 copies. def refresh_fp32_params(self): self._restore_from_fp16_weights() def load_state_dict(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False): """ Loads a state_dict created by an earlier call to state_dict(). If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, whose parameters in turn came from ``model``, it is expected that the user will call ``model.load_state_dict()`` before ``fp16_optimizer_instance.load_state_dict()`` is called. Example:: model = torch.nn.Linear(D_in, D_out).cuda().half() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) ... checkpoint = torch.load("saved.pth") model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) """ # I think it should actually be ok to reload the optimizer before the model. self.loss_scaler = state_dict_list[0]['loss_scaler'] self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale'] self.overflow = state_dict_list[0]['overflow'] if load_optimizer_states: self._restore_base_optimizer_state(state_dict_list) if load_from_fp32_weights: self._restore_from_fp32_weights(state_dict_list) else: self._restore_from_fp16_weights()
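# Illustrative checkpoint round-trip for FP16_DeepSpeedZeroOptimizer_Stage1 (a sketch under
# assumptions: `model`, `base_optimizer`, and the per-rank file naming are placeholders, and
# torch.distributed is already initialized). Each rank only holds its own fp32 sub-partitions,
# so every rank saves its own state_dict(), and load_state_dict() expects the state dicts
# gathered from all data-parallel ranks:
#
#   optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(base_optimizer, dynamic_loss_scale=True)
#   rank = torch.distributed.get_rank()
#   torch.save({'optimizer': optimizer.state_dict()}, 'zero_ckpt_rank{}.pth'.format(rank))
#   ...
#   world_size = torch.distributed.get_world_size()
#   state_dicts = [torch.load('zero_ckpt_rank{}.pth'.format(r))['optimizer']
#                  for r in range(world_size)]
#   optimizer.load_state_dict(state_dicts,
#                             load_optimizer_states=True,
#                             load_from_fp32_weights=True)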