def save_fp16_optimizer(args, model, optimizer, partial=True):
    optimizer_state_dict = {}

    # Detach the model from the loss scaler before deep-copying it, so the model
    # weights are not duplicated inside the checkpoint.
    loss_scaler = optimizer.loss_scaler
    _model = loss_scaler.model
    loss_scaler.model = None
    _loss_scaler = copy.deepcopy(loss_scaler)
    loss_scaler.model = _model

    optimizer_state_dict["loss_scaler"] = _loss_scaler
    optimizer_state_dict["dynamic_loss_scale"] = optimizer.dynamic_loss_scale
    optimizer_state_dict["overflow"] = optimizer.overflow
    optimizer_state_dict["first_closure_call_this_step"] = optimizer.first_closure_call_this_step

    # Move the fp32 master copies of the fp16 parameters to CPU for saving.
    cpu_fp32_from_fp16_groups = [
        [param.cpu() for param in group] for group in optimizer.fp32_from_fp16_groups
    ]
    if optimizer.master_params_created:
        register_optimizer_hooks(model)

    if partial:
        # Partial save: each rank saves only its local shard of the optimizer state.
        optimizer_state_dict["optimizer_state_dict"] = optimizer.local_state_dict()
        if args.shard_optimizer_state:
            if smp.rdp_rank() == 0:
                print(
                    "With shard_optimizer_state=True, gather full fp32_from_fp16_groups "
                    "for the rdp_group on rdp rank 0"
                )
                gathered_cpu_fp32_from_fp16_groups = [cpu_fp32_from_fp16_groups]
                for src in range(1, smp.rdp_size()):
                    gathered_cpu_fp32_from_fp16_groups.append(
                        smp.recv_from(src, smp.RankType.RDP_RANK)
                    )
                optimizer_state_dict["fp32_from_fp16"] = gathered_cpu_fp32_from_fp16_groups
            else:
                smp.send(cpu_fp32_from_fp16_groups, 0, smp.RankType.RDP_RANK)
                optimizer_state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups
        else:
            optimizer_state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups
        if smp.pp_size() > 1:
            print(
                "WARNING: Ensure that the partition decision doesn't change between runs "
                "(you can ensure this by setting use_times=False in the smp config). "
                "If you want to save and load with the partition decision changing "
                "between runs, use full save and load instead."
            )
    else:
        # Full save: merge the fp32 master params across TP and PP ranks.
        optimizer_state_dict["optimizer_state_dict"] = optimizer.state_dict()
        if smp.tp_size() > 1 and not args.shard_optimizer_state:
            tp_merged_fp32_from_fp16_groups, param_name_groups = get_tp_merged_fp32_from_fp16_param_groups(
                optimizer, cpu_fp32_from_fp16_groups
            )
            pp_merged_fp32_from_fp16_groups, param_name_groups = get_pp_merged_fp32_from_fp16_param_groups(
                optimizer, tp_merged_fp32_from_fp16_groups, param_name_groups
            )
        else:
            raise ValueError(
                "Saving the full optimizer state is not supported when TP is not enabled "
                "or shard_optimizer_state is enabled"
            )
        optimizer_state_dict["fp32_from_fp16"] = pp_merged_fp32_from_fp16_groups
        optimizer_state_dict["param_name_groups"] = param_name_groups

    return optimizer_state_dict
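# A minimal usage sketch (not from the source): how save_fp16_optimizer might be
# wired into a checkpointing routine. The names `save_checkpoint`, `output_dir`,
# and `step`, the use of model.local_state_dict(), and saving via smp.save with
# partial=True are assumptions for illustration, not part of the original code.
import os

def save_checkpoint(args, model, optimizer, output_dir, step):
    # Collect the fp16-optimizer state; partial=True keeps only this rank's shard.
    opt_state = save_fp16_optimizer(args, model, optimizer, partial=True)
    checkpoint = {
        "optimizer": opt_state,
        "model": model.local_state_dict(),  # assumption: partial model state per rank
        "step": step,
    }
    # Assumption: smp.save with partial=True writes per-rank checkpoint files.
    smp.save(
        checkpoint,
        os.path.join(output_dir, f"checkpoint_{step}.pt"),
        partial=True,
    )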
def create_pretraining_dataloader(
    input_paths: List[str],
    batch_size: int,
    max_sequence_length: int,
    seed: int,
    dp_rank: int,
    dp_size: int,
    shuffle: bool = False,
    zipped: bool = True,
    use_last_file_only: bool = False,
    data_type: str = "GPT",
):
    if smp.pp_rank() == 0:
        if data_type == "GPT":
            data = GPTPretrainingDataset(
                input_paths=input_paths,
                max_sequence_length=max_sequence_length,
                zipped=zipped,
                use_last_file_only=use_last_file_only,
            )
        elif data_type == "BERT":
            if len(input_paths) > 1:
                print(
                    "BERT data only supports a single file when calling "
                    "create_pretraining_dataloader; reading the first file instead..."
                )
            data = BertPretrainingDataset(
                input_file=input_paths[0], max_pred_length=max_sequence_length
            )
        else:
            raise ValueError(f"Unsupported data type {data_type}")

        # TODO: set sampler.epoch to correctly shuffle across epochs, else the same
        # order will be used for all epochs. Not relevant now as we have no epochs.
        sampler = torch.utils.data.DistributedSampler(
            data,
            shuffle=shuffle,
            seed=seed,
            rank=dp_rank,
            num_replicas=dp_size,
            drop_last=True,
        )
        dataloader = torch.utils.data.DataLoader(
            data,
            sampler=sampler,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
        # Tell the other pipeline-parallel ranks how many batches to expect.
        smp.broadcast(len(dataloader), smp.PP_GROUP)
    else:
        # Non-zero pp ranks do not load real data; they iterate over a dummy dataset
        # of the same length so all ranks execute the same number of steps.
        data_len = smp.recv_from(0, smp.RankType.PP_RANK)
        dataset = DummyDataset(data_len * batch_size, data_type=data_type)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, drop_last=True
        )
    return dataloader
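# A minimal usage sketch (not from the source): constructing the GPT dataloader and
# stepping through it. The file path and hyperparameter values are placeholders, and
# the smp.dp_rank()/smp.dp_size() calls assume smp.init() has already been called.
train_paths = ["/path/to/train_shard_0.json.gz"]  # hypothetical input file
train_dataloader = create_pretraining_dataloader(
    input_paths=train_paths,
    batch_size=8,
    max_sequence_length=1024,
    seed=1234,
    dp_rank=smp.dp_rank(),
    dp_size=smp.dp_size(),
    shuffle=True,
    zipped=True,
    data_type="GPT",
)
for batch in train_dataloader:
    # Batch contents depend on GPTPretrainingDataset (typically token ids and masks);
    # they would normally be passed to an smp.step-decorated training function here.
    pass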
def create_pretraining_dataloader(
    input_paths: List[str],
    batch_size: int,
    max_sequence_length: int,
    seed: int,
    dp_rank: int,
    dp_size: int,
    shuffle: bool = False,
    zipped: bool = True,
    use_last_file_only: bool = False,
    data_type: str = "openwebtext",
):
    if smp.pp_rank() == 0:
        if data_type == "openwebtext":
            data = OpenwebtextPretrainingDataset(
                input_paths=input_paths,
                max_sequence_length=max_sequence_length,
                zipped=zipped,
                use_last_file_only=use_last_file_only,
            )
        elif data_type == "wiki":
            if len(input_paths) > 1:
                print(
                    "Wiki data only supports a single file when calling "
                    "create_pretraining_dataloader; reading the first file instead..."
                )
            data = WikiPretrainingDataset(
                input_file=input_paths[0], max_pred_length=max_sequence_length
            )
        else:
            raise ValueError(f"Unsupported data type {data_type}")

        sampler = torch.utils.data.DistributedSampler(
            data,
            shuffle=shuffle,
            seed=seed,
            rank=dp_rank,
            num_replicas=dp_size,
            drop_last=True,
        )
        dataloader = torch.utils.data.DataLoader(
            data,
            sampler=sampler,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
        # Tell the other pipeline-parallel ranks how many batches to expect.
        smp.broadcast(len(dataloader), smp.PP_GROUP)
    else:
        # Non-zero pp ranks iterate over a dummy dataset of matching length.
        data_len = smp.recv_from(0, smp.RankType.PP_RANK)
        dataset = DummyDataset(data_len * batch_size, data_type=data_type)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, drop_last=True
        )
    return dataloader
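# DummyDataset is used by both dataloader variants above but is not defined in this
# section. A minimal sketch is shown below, assuming the dummy samples only need to
# match the real dataset's length so that non-zero pp ranks run the same number of
# steps; the field layout and placeholder contents are assumptions for illustration.
import torch

class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, length, data_type="openwebtext"):
        self.length = length
        self.data_type = data_type

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Placeholder tensor; real inputs are only materialized on pp_rank 0.
        return torch.zeros(1, dtype=torch.long)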
def get_tp_merged_fp32_from_fp16_param_groups(optimizer, cpu_fp32_from_fp16_groups):
    def _merge_param_group_tp_group(group_idx, param_group):
        result_fp32_from_fp16_param_group = []
        param_name_group = {}
        for i, param in enumerate(param_group):
            # For each param, obtain its name from the tp_rank 0 mappings.
            param_index = param_id_to_index_tp_group[rank_0][
                fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i]
            ]
            param_name = param_index_to_name_tp_group[rank_0][param_index]
            # Obtain the distribution axis for the param and check whether it is distributed.
            axis = master_distribution_axis_tp_rank_0.get(
                fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i], None
            )
            if axis is not None:
                tensors = []
                for r in range(smp.tp_size()):
                    # If distributed, map the param name back to each tp rank's param id
                    # via that rank's index mappings.
                    param_index_r = param_name_to_index_tp_group[r][param_name]
                    param_id_r = param_index_to_id_tp_group[r][param_index_r]
                    # Find the position of that param id within the rank's fp32_from_fp16 group.
                    group_param_idx = fp32_from_fp16_paramid_groups_tp_group[r][group_idx].index(
                        param_id_r
                    )
                    # Use the corresponding fp32 shard for concatenation along the distribution axis.
                    tensors.append(
                        fp32_from_fp16_param_groups_tp_group[r][group_idx][group_param_idx]
                    )
                result_fp32_from_fp16_param_group.append(torch.cat(tensors, axis))
            else:
                # If not distributed, use the tp_rank 0 param as-is.
                result_fp32_from_fp16_param_group.append(param)
            param_name_group[param_name] = i
        return result_fp32_from_fp16_param_group, param_name_group

    # Mappings between parameter indices and names across the TP group.
    param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group
    param_name_to_index_tp_group = smp_state.param_name_to_index_tp_group

    # Mappings between parameter ids and indices across the TP group.
    param_id_to_index = optimizer._param_id_to_index()
    param_id_to_index_tp_group = smp.allgather(param_id_to_index, smp.TP_GROUP)
    param_index_to_id_tp_group = _get_param_index_to_id(param_id_to_index_tp_group)

    # Allgather the param ids and fp32 master params for fp32_from_fp16_groups.
    fp32_from_fp16_paramid_groups = optimizer.fp32_from_fp16_paramid_groups
    fp32_from_fp16_paramid_groups_tp_group = smp.allgather(
        fp32_from_fp16_paramid_groups, smp.TP_GROUP
    )
    fp32_from_fp16_param_groups_tp_group = smp.allgather(
        cpu_fp32_from_fp16_groups, smp.TP_GROUP
    )

    # Broadcast the distribution axes from tp_rank 0 to all tp ranks.
    master_distribution_axis_tp_rank_0 = None
    if smp.tp_rank() == 0:
        master_distribution_axis_tp_rank_0 = optimizer.master_distribution_axis
        smp.broadcast(master_distribution_axis_tp_rank_0, smp.TP_GROUP)
    else:
        master_distribution_axis_tp_rank_0 = smp.recv_from(0, smp.RankType.TP_RANK)

    result_fp32_from_fp16_param_groups = []
    param_name_groups = []
    rank_0 = 0
    # Iterate over the param groups gathered from tp_rank 0 and merge each one.
    for group_idx, param_group in enumerate(fp32_from_fp16_param_groups_tp_group[rank_0]):
        result_fp32_from_fp16_param_group, param_name_group = _merge_param_group_tp_group(
            group_idx, param_group
        )
        result_fp32_from_fp16_param_groups.append(result_fp32_from_fp16_param_group)
        param_name_groups.append(param_name_group)
    return result_fp32_from_fp16_param_groups, param_name_groups
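# _get_param_index_to_id is called above but not defined in this section. A minimal
# sketch, assuming it simply inverts each tp rank's {param_id: param_index} mapping
# gathered across the TP group; treat this as an illustration rather than the
# original helper's implementation.
def _get_param_index_to_id(param_id_to_index_tp_group):
    param_index_to_id_tp_group = []
    for param_id_to_index in param_id_to_index_tp_group:
        # Invert {param_id: param_index} into {param_index: param_id} for one tp rank.
        param_index_to_id_tp_group.append(
            {index: param_id for param_id, index in param_id_to_index.items()}
        )
    return param_index_to_id_tp_group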