def split_state_dict(self,
                     mp_world_size,
                     mp_rank,
                     quantize=False,
                     quantize_bits=8,
                     groups=64,
                     mlp_extra_grouping=True):
    #self.sanity_check(self.ckpt_list[0])

    sd, num_to_split, ckpt_offset = self.get_split_state_dict(mp_world_size, mp_rank)
    ds_sd = copy.deepcopy(sd)
    new_client_sd = collections.OrderedDict()
    client_sd = self.get_module(sd)

    ckpt_ver = self.get_checkpoint_version(ds_sd)
    logger.info(f"checkpoint version: {ckpt_ver}")

    if quantize:
        quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
                                       mp_size=mp_world_size)

    for key in client_sd.keys():
        value = client_sd[key]

        if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
            assert value.shape[1] % num_to_split == 0
            split_size = value.shape[1] // num_to_split
            if quantize:
                q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                value = q_vals[0]
            new_client_sd[key] = torch.split(value, split_size, dim=1)[ckpt_offset]
        elif "attention.query_key_value" in key:
            if quantize and "attention.query_key_value.weight" in key:
                q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                value = q_vals[0]
            new_client_sd[key] = self.split_query_key_value(value,
                                                            num_to_split,
                                                            ckpt_offset,
                                                            ckpt_ver)
        elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key \
                or "mlp.dense_h_to_4h.bias" in key or "final_linear.weight" in key:
            assert value.shape[0] % num_to_split == 0
            split_size = value.shape[0] // num_to_split
            if quantize and "mlp.dense_h_to_4h.weight" in key:
                q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                value = q_vals[0]
            new_client_sd[key] = torch.split(value, split_size, dim=0)[ckpt_offset]
        else:
            new_client_sd[key] = value

    if quantize:
        all_scales = quantizer.merge_scales_split(num_to_split)

    ds_sd = self.set_module(ds_sd, new_client_sd)

    return ds_sd, (all_scales if quantize else None)
def load_state_dir(self, load_dir, strict=True):
    for idx, layer in enumerate(self.forward_funcs):
        # Functions, etc. will not have state_dicts
        if not hasattr(layer, 'load_state_dict'):
            continue

        # get all checkpoint files for the layer.
        model_ckpt_list = self.ckpt_layer_path_list(load_dir, idx)
        mp_rank = self._grid.get_slice_parallel_rank()
        mp_world_size = self._grid.get_slice_parallel_world_size()

        sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list, version=2.0)
        load_path, checkpoint, _ = sd_loader.load(mp_world_size,
                                                  mp_rank,
                                                  module_key=None,
                                                  is_pipe_parallel=True)

        # Honor the caller's strict flag instead of silently ignoring it.
        layer.load_state_dict(checkpoint, strict=strict)

        if self._grid.data_parallel_id == 0:
            logger.info(
                f'RANK={self.global_rank} Loaded layer={idx + self._local_start} file={load_path}'
            )

    self._synchronize_tied_weights()
def save(self, state_dict, path: str):
    tag = _get_tag_from_path(path)
    partition_name = os.path.basename(path)
    logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
    self.checkpoint.save(partition_name, state_dict)
    logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")
    return None
def __init__(self,
             gate: Module,
             experts: Module,
             num_local_experts: int,
             group: Optional[Any] = None,
             use_tutel: bool = False) -> None:
    super().__init__()
    self.gate = gate
    self.experts = experts
    self.group = group
    self.world_size = dist.get_world_size(group)
    self.num_local_experts = num_local_experts
    self.time_falltoall = 0.0
    self.time_salltoall = 0.0
    self.time_moe = 0.0
    self.timers = SynchronizedWallClockTimer()
    self.wall_clock_breakdown = False

    self.use_tutel = use_tutel and TUTEL_INSTALLED

    if self.use_tutel:
        logger.info('Using Tutel optimizations.')
    elif use_tutel and not TUTEL_INSTALLED:
        logger.warning("Tutel optimization requested but not installed. "
                       "Proceeding without Tutel.")
def override_loss_scale(self, loss_scale):
    if loss_scale != self.external_loss_scale:
        logger.info(
            f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
    self.custom_loss_scaler = True
    self.external_loss_scale = loss_scale
def merge_state_dict(self,
                     mp_world_size,
                     mp_rank,
                     quantize=False,
                     quantize_bits=8,
                     groups=64,
                     mlp_extra_grouping=True):
    self.sanity_check(self.ckpt_list[0])

    sd_list = self.get_merge_state_dicts(mp_world_size, mp_rank)
    ds_sd = copy.deepcopy(sd_list[0])
    new_client_sd = collections.OrderedDict()
    client_sd_list = [self.get_module(sd) for sd in sd_list]
    keys = client_sd_list[0].keys()

    ckpt_ver = self.get_checkpoint_version(ds_sd)
    logger.info(f"checkpoint version: {ckpt_ver}")

    if quantize:
        quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
                                       mp_size=mp_world_size)

    for key in keys:
        value_list = [sd[key] for sd in client_sd_list]

        if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
            if quantize:
                value_list = quantizer.Quantize(value_list,
                                                quantize_bits,
                                                groups,
                                                key=key,
                                                merge_dim=1)
            new_client_sd[key] = torch.cat(value_list, axis=1)
        elif "attention.query_key_value" in key:
            if quantize and "attention.query_key_value.weight" in key:
                value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key)
                new_client_sd[key] = torch.cat(value_list, axis=0)
            else:
                if quantize:
                    new_client_sd[key] = torch.cat(value_list, axis=0)
                else:
                    new_client_sd[key] = self.merge_query_key_value(value_list, ckpt_ver)
        elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key \
                or "mlp.dense_h_to_4h.bias" in key:
            if quantize and "mlp.dense_h_to_4h.weight" in key:
                value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key)
            new_client_sd[key] = torch.cat(value_list, axis=0)
        else:
            new_client_sd[key] = value_list[0]

    if quantize:
        all_scales = quantizer.merge_scales()

    ds_sd = self.set_module(ds_sd, new_client_sd)

    return ds_sd, (all_scales if quantize else None), len(client_sd_list)
def __init__(self,
             gate: Module,
             experts: Module,
             ep_group_name,
             ep_size,
             num_local_experts: int,
             use_tutel: bool = False) -> None:
    super().__init__()
    self.gate = gate
    self.experts = experts
    self.ep_group = None
    self.ep_size = ep_size
    self.ep_group_name = ep_group_name
    self.num_local_experts = num_local_experts
    self.time_falltoall = 0.0
    self.time_salltoall = 0.0
    self.time_moe = 0.0
    self.timers = SynchronizedWallClockTimer()
    self.wall_clock_breakdown = False

    self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1

    if self.use_tutel:
        logger.info('Using Tutel optimizations.')
    elif use_tutel and not TUTEL_INSTALLED:
        logger.warning("Tutel optimization requested but not installed. "
                       "Proceeding without Tutel.")
    elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
        logger.warning(
            "To enable Tutel optimization, use top-1 instead of top-2 gate. "
            "Proceeding without Tutel.")
def load(self,
         mp_world_size,
         mp_rank,
         module_key=AUTO_MODULE_KEY,
         is_pipe_parallel=False,
         quantize=False,
         quantize_bits=8,
         quantize_groups=64,
         mlp_extra_grouping=True):
    self.module_key = module_key
    num_ckpt = len(self.ckpt_list)
    idx = mp_rank * num_ckpt // mp_world_size

    logger.info(
        f'mp_world_size: {mp_world_size}, mp_rank: {mp_rank}, module_key: {module_key}')
    """ We have multiple cases to handle here for both training and inference:
        1. PipeModule loading mp_rank_*.pt files, is_pipe_parallel=True, module_key is not None
            a. if no mp_size/pp_size resizing occurs, for both training and inference, load the
               mp_rank-related checkpoint directly.
            b. if mp_size/pp_size resizing occurs, only Megatron model inference is supported.
               In this case each mp_rank_*.pt file has the same content, so we load the first
               checkpoint file (idx=0) to avoid the index exceeding the file-list boundary.
        2. PipeModule loading layer_*.pt files, is_pipe_parallel=True, module_key is None
            a. if no mp_size resizing occurs, for both training and inference, load the
               mp_rank-related checkpoint directly.
            b. if mp_size resizing occurs, only Megatron model inference is supported.
               Checkpoint file(s) will be merged/split according to mp_rank, mp_world_size
               and the checkpoint file list.
        3. Non-PipeModule loading mp_rank_*.pt files, is_pipe_parallel=False
            Same as case (2).
    """
    if is_pipe_parallel and module_key is not None and mp_world_size != num_ckpt:
        mp_world_size = num_ckpt
        idx = 0

    load_path = self.ckpt_list[idx]

    merge_count = 1
    if num_ckpt == mp_world_size:
        assert os.path.exists(load_path)
        logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}')
        sd = torch.load(load_path, map_location=lambda storage, loc: storage)

        if quantize:
            quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
                                           mp_size=mp_world_size)
            sd_module, all_scales = quantizer.sd_quantize_megatron(self.get_module(sd),
                                                                   quantize_bits,
                                                                   quantize_groups)
            self.set_module(sd, sd_module)
        else:
            all_scales = None
    elif num_ckpt > mp_world_size:
        sd, all_scales, merge_count = self.merge_state_dict(mp_world_size, mp_rank, quantize, \
            quantize_bits, quantize_groups, mlp_extra_grouping)
    else:
        sd, all_scales = self.split_state_dict(mp_world_size, mp_rank, quantize, quantize_bits, \
            quantize_groups, mlp_extra_grouping)

    return load_path, sd, (all_scales, merge_count)
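The three branches at the bottom of `load` implement the cases described in the docstring: a one-to-one load when the file count matches the model-parallel world size, a merge when there are more files than ranks, and a split when there are fewer. A minimal standalone sketch of that selection logic (the helper name `choose_load_strategy` is hypothetical and for illustration only):

```python
# Hypothetical helper, not part of the loader: illustrates how the checkpoint
# file count and the target model-parallel world size select a load strategy.
def choose_load_strategy(num_ckpt: int, mp_world_size: int) -> str:
    if num_ckpt == mp_world_size:
        return "direct"  # one mp_rank_*.pt file per model-parallel rank
    elif num_ckpt > mp_world_size:
        return "merge"   # several files are concatenated per rank
    else:
        return "split"   # each file is sliced across several ranks

assert choose_load_strategy(4, 4) == "direct"
assert choose_load_strategy(8, 2) == "merge"
assert choose_load_strategy(2, 4) == "split"
```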
def tune(self, sample_size=1, n_trials=1000, early_stopping=None):
    i = 0
    try:
        while i < n_trials and self.has_next():
            # Select the next batch of configurations for evaluation
            sampled_exps = self.next_batch(sample_size)
            # Generate experiments for measurement of performance
            exp_paths = write_experiments(sampled_exps, self.rm.exps_dir)
            self.rm.schedule_experiments(exp_paths)
            self.rm.run()
            exp, metric_val = self.rm.parse_results(self.metric)
            if self.best_exp is None or self.best_metric_val is None or (
                    metric_val and metric_val > self.best_metric_val):
                # logger.info(f"tuner finds better = {exp}")
                self.best_exp = exp
                self.best_metric_val = metric_val
                self.best_iter = i

            i += len(sampled_exps)

            # Update the tuner with evaluated performance results
            self.update()

            self.rm.clear()

            # Early stop if no more promising configurations are likely to be found
            if early_stopping and i >= self.best_iter + early_stopping:
                logger.info(
                    f"Tuner early stopped at iteration {i}. Best iteration is {self.best_iter}. "
                    f"Early stopping threshold is {early_stopping}.")
                break
        return i
    except Exception:
        logger.info(f"Tuner error: {sys.exc_info()[0]}")
        return i
def step_fused_lamb(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute norm for all groups so we know if there is overflow
    grads_groups_flat = []
    grads_groups = []
    norm_groups = []
    expert_norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads = [
            torch.zeros(p.size(), dtype=p.dtype, device=p.device)
            if p.grad is None else p.grad for p in group
        ]
        grads_groups.append(grads)
        grads_groups_flat.append(_flatten_dense_tensors(grads))
        grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(group)
        norm_group_value = 0.0
        if len(grads_for_norm) > 0:
            norm_group_value = get_weight_norm(_flatten_dense_tensors(grads_for_norm),
                                               mpu=self.mpu)
        norm_groups.append(norm_group_value)
        expert_norm_group_value = 0.0
        if len(expert_grads_for_norm) > 0:
            expert_norm_group_value = get_weight_norm(
                _flatten_dense_tensors(expert_grads_for_norm),
                mpu=self.mpu)
        expert_norm_groups.append(expert_norm_group_value)

    self.overflow = self.overflow_checker.check_using_norm(norm_groups + expert_norm_groups)
    prev_scale = self.cur_scale
    self._update_scale(self.overflow)

    if self.overflow:
        if self.verbose:
            logger.info(
                "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
    self.optimizer.step(grads=grads_groups,
                        output_params=self.fp16_groups,
                        scale=combined_scale)

    for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
        for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)):
            # remove the fp32 grad
            fp32_param.grad = None

            # copy data from fp32 to fp16
            fp16_param.data.copy_(fp32_param.data)

    return self.overflow
def partition_activations_in_checkpoint(partition_activation):
    global PARTITION_ACTIVATIONS
    PARTITION_ACTIVATIONS = partition_activation
    if dist.get_rank() == 0:
        logger.info(
            f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
def _partition_layers(self, method='uniform'):
    num_stages = self._topo.get_dim('pipe')
    stage_id = self._topo.get_coord(self.global_rank).pipe

    if self.global_rank == 0:
        logger.info(f'Partitioning pipeline stages with method {method}')

    method = method.lower()

    # Each stage gets a simple uniform number of layers.
    if method == 'uniform':
        num_layers = len(self._layer_specs)
        self.parts = ds_utils.partition_uniform(num_items=num_layers,
                                                num_parts=num_stages)
    elif method == 'parameters':
        param_counts = self._count_layer_params()
        self.parts = ds_utils.partition_balanced(weights=param_counts,
                                                 num_parts=num_stages)
    elif method.startswith('type:'):
        layertype = method.split(':')[1]
        binary_weights = [0] * len(self._layer_specs)
        for idx in self._find_layer_type(layertype):
            binary_weights[idx] = 1
        self.parts = ds_utils.partition_balanced(weights=binary_weights,
                                                 num_parts=num_stages)
    elif method == 'profile':
        raise NotImplementedError(f'Partitioning method {method} not implemented.')
    else:
        raise NotImplementedError(f'Partitioning method {method} not implemented.')

    # Print some information on the partitioning.
    if self.global_rank == 0:
        for stage in range(num_stages):
            start = self.parts[stage]
            stop = self.parts[stage + 1]
            print(f'stage={stage} layers={stop - start}')
            for idx, layer in enumerate(self._layer_specs[start:stop]):
                name = str(layer)
                if isinstance(layer, LayerSpec):
                    name = layer.typename.__name__
                if isinstance(layer, nn.Module):
                    name = layer.__class__.__name__
                else:
                    try:
                        name = layer.__name__
                    except AttributeError:
                        pass
                print(f'    {idx + start:2d}: {name}')
        if self.loss_fn:
            try:
                print(f'  loss: {self.loss_fn.__name__}')
            except AttributeError:
                print(f'  loss: {self.loss_fn.__class__.__name__}')

    self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1])
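`self.parts` is a list of stage boundaries with one trailing sentinel, so stage `s` owns layers `parts[s]` through `parts[s+1] - 1`. The snippet below is a hand-rolled stand-in for `ds_utils.partition_uniform` (the real helper's remainder handling may differ) that only shows the boundary format the rest of the method consumes:

```python
# Illustrative stand-in for ds_utils.partition_uniform; shows only the boundary
# format (num_parts + 1 entries), not the library's exact remainder policy.
def uniform_boundaries(num_items: int, num_parts: int):
    base, rem = divmod(num_items, num_parts)
    parts = [0]
    for s in range(num_parts):
        parts.append(parts[-1] + base + (1 if s < rem else 0))
    return parts

parts = uniform_boundaries(num_items=24, num_parts=4)  # [0, 6, 12, 18, 24]
start, stop = parts[2], parts[3]                       # stage 2 owns layers 12..17
```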
def __init__(self, exps, resource_manager, metric):
    self.all_exps = exps
    self.rm = resource_manager
    self.best_iter = 0
    self.best_exp = None
    self.best_metric_val = None
    self.metric = metric if metric else AUTOTUNING_METRIC_DEFAULT
    logger.info(f"total number of exps = {len(self.all_exps)}")
def check_ckpt_list(self):
    logger.info(f'checkpoint file list: {self.ckpt_list}')
    assert len(self.ckpt_list) > 0

    sd = torch.load(self.ckpt_list[0], map_location=lambda storage, loc: storage)

    # check that the checkpoint count matches the saved mp_world_size
    if 'mp_world_size' in sd.keys():
        assert len(self.ckpt_list) == sd['mp_world_size'], \
            f"checkpoint count {len(self.ckpt_list)} is different from saved mp_world_size {sd['mp_world_size']}"
def step(self, closure=None): """ Not supporting closure. """ if self.fused_lamb_legacy: return self.step_fused_lamb() self.overflow = self.overflow_checker.check() prev_scale = self.cur_scale self._update_scale(self.overflow) if self.overflow: if self.verbose: logger.info( "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss " "scale: {}, reducing to {}".format(prev_scale, self.cur_scale)) return self.overflow norm_groups = [] for i, group in enumerate(self.fp16_groups): grads_for_norm, _ = split_params_grads_into_shared_and_expert_params( group) norm_group_value = 0.0 if len(grads_for_norm) > 0: norm_group_value = get_weight_norm(grads_for_norm, mpu=self.mpu) norm_groups.append(norm_group_value) # copying gradients to fp32 to wor k with fp32 parameters for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]): if fp16_param.grad is None: fp32_param.grad = torch.zeros(fp16_param.size(), dtype=fp32_param.dtype, device=fp32_param.device) else: fp32_param.grad = fp16_param.grad.to(fp32_param.dtype) self._global_grad_norm = get_global_norm(norm_list=norm_groups) self.unscale_and_clip_grads(self._global_grad_norm) self.optimizer.step() for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups): for idx, (fp32_param, fp16_param) in enumerate(zip(fp32_group, fp16_group)): #remove the fp32 grad fp32_param.grad = None #copy data from fp32 to fp16 fp16_param.data.copy_(fp32_param.data) return self.overflow
def _configure_using_config_file(deepspeed_config, mpu=None):
    global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
        PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME

    config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config
    logger.info(config.repr())
    PARTITION_ACTIVATIONS = config.partition_activations
    CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
    num_layers = config.number_checkpoints
    PA_TO_CPU = config.cpu_checkpointing
    SYNCHRONIZE = config.synchronize_checkpoint_boundary
    PROFILE_TIME = config.profile
def _handle_overflow(cpu_sum, x, i):
    import math
    rank = torch.distributed.get_rank()
    if rank == 0:
        t_i = -1
        for v_i, v in enumerate(x.data.contiguous().view(-1)):
            if not math.isfinite(float(v)):
                t_i = v_i
                break
        logger.info(
            f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}")
def commit(self, tag):
    # Nebula commit is called once all files under the given tag are ready to be
    # persisted asynchronously.
    logger.info(
        f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting")
    commit_rls = self.checkpoint.commit()
    if not commit_rls:
        logger.error(f"[Nebula] failed to commit the checkpoint, please check the log.")
        return False
    return commit_rls
def _set_batch_related_parameters(self):

    train_batch = self.train_batch_size
    micro_batch = self.train_micro_batch_size_per_gpu
    grad_acc = self.gradient_accumulation_steps

    # all values are provided, nothing needs to be set
    if train_batch is not None and \
            micro_batch is not None and \
            grad_acc is not None:
        return

    # gradient_accumulation_steps needs to be set
    elif train_batch is not None and \
            micro_batch is not None:
        grad_acc = train_batch // micro_batch
        grad_acc //= self.world_size
        self.gradient_accumulation_steps = grad_acc

    # micro_batch_per_gpu needs to be set
    elif train_batch is not None and \
            grad_acc is not None:
        micro_batch = train_batch // self.world_size
        micro_batch //= grad_acc
        self.train_micro_batch_size_per_gpu = micro_batch

    # train_batch_size needs to be set
    elif micro_batch is not None and \
            grad_acc is not None:
        train_batch_size = micro_batch * grad_acc
        train_batch_size *= self.world_size
        self.train_batch_size = train_batch_size

    # only train_batch_size is set: derive micro_batch_per_gpu, default grad_acc to 1
    elif train_batch is not None:
        self.gradient_accumulation_steps = 1
        self.train_micro_batch_size_per_gpu = train_batch // self.world_size

    # only micro_batch_per_gpu is set: derive train_batch_size, default grad_acc to 1
    elif micro_batch is not None:
        self.train_batch_size = micro_batch * self.world_size
        self.gradient_accumulation_steps = 1

    # either none of the three parameters is provided, or only gradient_accumulation_steps is provided
    else:
        assert False, \
            'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided'

    logger.info(
        f'After configuration: train_batch={self.train_batch_size} '
        f'micro_batch={self.train_micro_batch_size_per_gpu} '
        f'grad_acc={self.gradient_accumulation_steps}')
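Every branch above preserves the invariant `train_batch_size == train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size`. A small worked example with illustrative values, following the branch where `train_batch_size` and `train_micro_batch_size_per_gpu` are given:

```python
# Worked example of the batch-size invariant (values are illustrative).
world_size = 8
train_batch_size = 256
train_micro_batch_size_per_gpu = 4

grad_acc = train_batch_size // train_micro_batch_size_per_gpu  # 64
grad_acc //= world_size                                        # 8

assert train_micro_batch_size_per_gpu * grad_acc * world_size == train_batch_size
```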
def see_memory_usage(message):
    # NOTE: the early return below disables memory reporting; remove it to re-enable.
    return
    if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
        return

    # Print message except when distributed but not rank 0
    logger.info(message)
    logger.info(
        f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        CA {round(torch.cuda.memory_cached() / (1024 * 1024 * 1024), 2)} GB \
        Max_CA {round(torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))} GB ")
def step_fused_adam(self, closure=None):
    """
    Not supporting closure.
    """
    # First compute norm for all groups so we know if there is overflow
    grads_groups_flat = []
    norm_groups = []
    for i, group in enumerate(self.fp16_groups):
        grads_groups_flat.append(
            _flatten_dense_tensors([
                torch.zeros(p.size(), dtype=p.dtype, device=p.device)
                if p.grad is None else p.grad for p in group
            ]))
        norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

    self.overflow = self.overflow_checker.check_using_norm(norm_groups)
    prev_scale = self.cur_scale
    self._update_scale(self.overflow)

    if self.overflow:
        if self.verbose:
            logger.info(
                "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
        return self.overflow

    scaled_grad_norm = get_global_norm(norm_list=norm_groups)

    combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
                                                 scaled_grad_norm,
                                                 apply_scale=False)

    # Stash unscaled gradient norm
    self._global_grad_norm = scaled_grad_norm / self.cur_scale

    # norm is in fact norm*cur_scale
    self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                        output_params=[[p] for p in self.fp16_groups_flat],
                        scale=combined_scale,
                        grad_norms=norm_groups)

    # TODO: we probably don't need this? just to be safe
    for i in range(len(norm_groups)):
        updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                  self.fp16_groups[i])
        for p, q in zip(self.fp16_groups[i], updated_params):
            p.data = q.data

    return self.overflow
def compute_quantization(self, input, index=0, factor=1):
    # Fix the quantization bits based on the training step. When reducing one bit at
    # each period, we increase the period so the precision moves slowly toward the
    # target quantization bits. The period and the starting bit can be configured.

    if input.start_bits != input.target_bits:
        if self.qsteps >= input.q_period:
            self.quantize_real_ratio = 1.0
            input.q_period <<= 1
            input.q_period *= factor
            input.start_bits -= 1
            if self.q_verbose:
                logger.info(
                    f'Quantization settings: current bit-precision = {input.start_bits}, '
                    f'step = {self.qsteps}, quantization period = {input.q_period}, index = {index}')
    assert (input.start_bits >= input.target_bits), \
        'Quantization bit is lower than target precision bits!'

    if self.use_quantizer_kernel:
        if input.start_bits <= 2:
            raise ValueError(
                'Quantization bit is too low, please do it without quantization kernel!')
        input_q = ds_quantizer(input.data.clone(),
                               self.q_groups,
                               input.start_bits,
                               asym=False if self.q_type == 'symmetric' else True,
                               sr=False if self.q_rounding == 'nearest_neighbor' else True)
    else:
        if input.start_bits >= 3:
            input_flat = self.quantize_highbit(input.data, input.start_bits)
        elif input.start_bits == 2:
            assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
            assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest_neighbor!'
            input_flat = self.quantize_tenary(input.data)
        elif input.start_bits == 1:
            assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
            assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest_neighbor!'
            input_flat = self.quantize_binary(input.data)

    if self.use_quantizer_kernel:
        return self.mixed_fp16_quantize(input.data, input_q, index)
    else:
        if self.q_mixed_fp16 and input.start_bits >= input.target_bits - 1:
            input_flat = self.quantize_real_ratio * input.data + \
                (1 - self.quantize_real_ratio) * input_flat
        return input_flat
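The schedule above drops one bit of precision each time the step counter crosses `q_period` and doubles the period afterwards, so later reductions happen more slowly. The standalone simulation below reproduces that schedule in isolation (illustrative only; in the real class `qsteps` is advanced elsewhere, and `factor` is taken as 1 here):

```python
# Standalone simulation of the bit-reduction schedule (illustration only).
start_bits, target_bits, q_period, factor = 8, 4, 100, 1
for step in range(2000):
    if start_bits != target_bits and step >= q_period:
        q_period <<= 1       # double the period...
        q_period *= factor   # ...and optionally stretch it further
        start_bits -= 1      # drop one bit of precision
# Precision falls 8->7 at step 100, 7->6 at 200, 6->5 at 400, and 5->4 at 800.
```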
def check_row_pruning(self):
    # check row pruning
    rp = self.different_compression_methods[ROW_PRUNING]
    if not rp[TECHNIQUE_ENABLED]:
        return
    else:
        shared_parameters = rp[SHARED_PARAMETERS]
        if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
            for group_name, module_name_list, method_parameters in rp[DIFFERENT_GROUPS]:
                for module_name in module_name_list:
                    module = recursive_getattr(self.model, module_name)
                    module.row_pruning_enabled = True

            if not self.verbose[ROW_PRUNING]:
                logger.info(f'Row pruning is enabled at step {self.training_steps}')
                self.verbose[ROW_PRUNING] = True
def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
    group_paddings = []
    flattened_size = sum([tensor.numel() for tensor in tensor_list])
    for i in range(sub_partition_count):
        padding = get_alignment_padding(flattened_size, i, sub_partition_size)
        group_paddings.append(padding)

    logger.info("****Padding information*****")
    logger.info(f"tensor_size = {flattened_size}")
    logger.info(f"sub_partition_size = {sub_partition_size}")
    logger.info(f"sub_partition_count = {sub_partition_count}")
    for i, padding in enumerate(group_paddings):
        logger.info(f"padding[{i}] = {padding}")

    return group_paddings
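`get_alignment_padding` is defined elsewhere; a plausible reading, consistent with the per-sub-partition paddings logged here, is that it returns how many elements are needed to fill sub-partition `i` when the flattened group ends inside it. The sketch below is that assumption, not the library's implementation:

```python
# Hedged sketch of one plausible alignment-padding rule (assumption, not the
# library's get_alignment_padding): pad sub-partition `i` only if the flattened
# tensors end inside it, so every sub-partition holds sub_partition_size elements.
def alignment_padding_sketch(flattened_size, sub_partition_id, sub_partition_size):
    high = (sub_partition_id + 1) * sub_partition_size
    if high <= flattened_size:
        return 0
    return min(sub_partition_size, high - flattened_size)

# With 10 elements and sub-partitions of size 4, only the last one needs padding.
assert [alignment_padding_sketch(10, i, 4) for i in range(3)] == [0, 0, 2]
```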
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    parameter_parallel_size = parameter_parallel_size or data_parallel_size
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s",
                data_parallel_size,
                parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(data_parallel_size // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
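For example, with a data-parallel world size of 8 and a parameter-parallel size of 2, the loop creates the rank groups [0, 1], [2, 3], [4, 5], [6, 7], and each rank keeps the group it belongs to. A pure-Python illustration of that grouping (no `torch.distributed` required):

```python
# Pure-Python illustration of the rank grouping built above.
data_parallel_size, parameter_parallel_size = 8, 2
groups = [
    list(range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size))
    for i in range(data_parallel_size // parameter_parallel_size)
]
assert groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
```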
def get_split_state_dict(self, mp_world_size, mp_rank):
    num_ckpt = len(self.ckpt_list)
    assert mp_world_size % num_ckpt == 0, 'Invalid checkpoints and world size for sd split'

    num_to_split = mp_world_size // num_ckpt
    ckpt_index = mp_rank // num_to_split
    ckpt_offset = mp_rank % num_to_split

    logger.info(
        f"mp_rank: {mp_rank}, ckpt_list: {self.ckpt_list[ckpt_index]}, offset: {ckpt_offset}")

    sd = torch.load(self.ckpt_list[ckpt_index],
                    map_location=lambda storage, loc: storage)

    return sd, num_to_split, ckpt_offset
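The index arithmetic maps each model-parallel rank to one source checkpoint plus an offset within that file's split. For example, splitting 2 checkpoint files across 4 ranks:

```python
# Worked example of the split indexing above (illustrative values).
mp_world_size, num_ckpt = 4, 2
num_to_split = mp_world_size // num_ckpt   # 2 ranks share each file
for mp_rank in range(mp_world_size):
    ckpt_index = mp_rank // num_to_split   # which file this rank reads
    ckpt_offset = mp_rank % num_to_split   # which slice of that file it keeps
    # ranks 0,1 -> file 0 (offsets 0,1); ranks 2,3 -> file 1 (offsets 0,1)
```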
def step(self, closure=None): """ Not supporting closure. """ if self.fused_lamb_legacy: return self.step_fused_lamb() self.overflow = self.overflow_checker.check() prev_scale = self.cur_scale self._update_scale(self.overflow) if self.overflow: if self.verbose: logger.info( "[deepspeed] OVERFLOW! Skipping step. Attempted loss " "scale: {}, reducing to {}".format(prev_scale, self.cur_scale)) return self.overflow norm_groups = [] for i, group in enumerate(self.fp16_groups): norm_groups.append(get_grad_norm(group, mpu=self.mpu)) # copying gradients to fp32 to work with fp32 parameters for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]): if fp16_param.grad is None: fp32_param.grad = torch.zeros(fp16_param.size(), dtype=fp32_param.dtype, device=fp32_param.device) else: fp32_param.grad = fp16_param.grad.to(fp32_param.dtype) self.unscale_and_clip_grads(norm_groups) self.optimizer.step() for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups): for fp32_param, fp16_param in zip(fp32_group, fp16_group): #remove the fp32 grad fp32_param.grad = None #copy data from fp32 to fp16 fp16_param.data.copy_(fp32_param.data) return self.overflow
def get_merge_state_dicts(self, mp_world_size, mp_rank):
    num_ckpt = len(self.ckpt_list)
    assert num_ckpt % mp_world_size == 0, 'Invalid checkpoints and world size for sd merge'

    num_to_merge = num_ckpt // mp_world_size
    ckpt_list = [
        self.ckpt_list[i]
        for i in range(num_to_merge * mp_rank, num_to_merge * (mp_rank + 1))
    ]

    logger.info(f"mp_rank: {mp_rank}, ckpt_list: {ckpt_list}")

    sd_list = [
        torch.load(ckpt, map_location=lambda storage, loc: storage) for ckpt in ckpt_list
    ]
    return sd_list
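Conversely, when there are more checkpoint files than model-parallel ranks, each rank merges a contiguous block of files. For example, merging 8 checkpoint files down to 2 ranks:

```python
# Worked example of the merge indexing above (illustrative values).
num_ckpt, mp_world_size = 8, 2
num_to_merge = num_ckpt // mp_world_size   # each rank merges 4 files
for mp_rank in range(mp_world_size):
    indices = list(range(num_to_merge * mp_rank, num_to_merge * (mp_rank + 1)))
    # rank 0 merges files 0..3, rank 1 merges files 4..7
```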
def check_activation_quantization(self):
    # check activation quantization
    aq = self.different_compression_methods[ACTIVATION_QUANTIZATION]
    if not aq[TECHNIQUE_ENABLED]:
        return
    else:
        shared_parameters = aq[SHARED_PARAMETERS]
        if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
            for group_name, module_name_list, method_parameters in aq[DIFFERENT_GROUPS]:
                for module_name in module_name_list:
                    module = recursive_getattr(self.model, module_name)
                    module.activation_quantization_enabled = True

            if not self.verbose[ACTIVATION_QUANTIZATION]:
                logger.info(
                    f'Activation quantization is enabled at step {self.training_steps}')
                self.verbose[ACTIVATION_QUANTIZATION] = True
def load_state_dir(self, load_dir, strict=True):
    rank = dist.get_rank()

    layer_offset = self._local_start
    for idx, layer in enumerate(self.forward_funcs):
        # Functions, etc. will not have state_dicts
        if not hasattr(layer, 'load_state_dict'):
            continue

        model_ckpt_path = self.ckpt_layer_path(load_dir, idx)
        layer.load_state_dict(torch.load(model_ckpt_path,
                                         map_location=lambda storage, loc: storage),
                              strict=strict)
        if self._grid.data_parallel_id == 0:
            logger.info(
                f'RANK={self.global_rank} Loaded layer={idx + layer_offset} file={model_ckpt_path}'
            )

    self._synchronize_tied_weights()