Example #1
import copy

import smdistributed.modelparallel.torch as smp

# register_optimizer_hooks and the get_*_merged_fp32_from_fp16_param_groups helpers
# are assumed to be defined elsewhere in the same training script.
def save_fp16_optimizer(args, model, optimizer, partial=True):
    optimizer_state_dict = {}
    loss_scaler = optimizer.loss_scaler
    _model = loss_scaler.model
    loss_scaler.model = None
    _loss_scaler = copy.deepcopy(loss_scaler)
    loss_scaler.model = _model
    optimizer_state_dict["loss_scaler"] = _loss_scaler
    optimizer_state_dict["dynamic_loss_scale"] = optimizer.dynamic_loss_scale
    optimizer_state_dict["overflow"] = optimizer.overflow
    optimizer_state_dict[
        "first_closure_call_this_step"] = optimizer.first_closure_call_this_step
    cpu_fp32_from_fp16_groups = [[param.cpu() for param in group]
                                 for group in optimizer.fp32_from_fp16_groups]
    if optimizer.master_params_created:
        register_optimizer_hooks(model)
    if partial:
        optimizer_state_dict[
            "optimizer_state_dict"] = optimizer.local_state_dict()
        if args.shard_optimizer_state:
            if smp.rdp_rank() == 0:
                print(
                    "With shard_optimizer_state=True, gather full fp32_from_fp16_groups for the rdp_group on rdp rank 0"
                )
                gathered_cpu_fp32_from_fp16_groups = [
                    cpu_fp32_from_fp16_groups
                ]
                for src in range(1, smp.rdp_size()):
                    gathered_cpu_fp32_from_fp16_groups.append(
                        smp.recv_from(src, smp.RankType.RDP_RANK))
                optimizer_state_dict[
                    "fp32_from_fp16"] = gathered_cpu_fp32_from_fp16_groups
            else:
                smp.send(cpu_fp32_from_fp16_groups, 0, smp.RankType.RDP_RANK)
                optimizer_state_dict[
                    "fp32_from_fp16"] = cpu_fp32_from_fp16_groups
        else:
            optimizer_state_dict["fp32_from_fp16"] = cpu_fp32_from_fp16_groups
        if smp.pp_size() > 1:
            print(
                "WARNING: Ensure that the partition decision doesn't change between runs (you can ensure this by setting use_times=False in the smp config). "
                "If you want to save and load with the partition decision changing between runs, use full save and load instead."
            )
    else:
        optimizer_state_dict["optimizer_state_dict"] = optimizer.state_dict()
        if smp.tp_size() > 1 and not args.shard_optimizer_state:
            tp_merged_fp32_from_fp16_groups, param_name_groups = get_tp_merged_fp32_from_fp16_param_groups(
                optimizer, cpu_fp32_from_fp16_groups)
            pp_merged_fp32_from_fp16_groups, param_name_groups = get_pp_merged_fp32_from_fp16_param_groups(
                optimizer, tp_merged_fp32_from_fp16_groups, param_name_groups)
        else:
            raise ValueError(
                "Saving the full optimizer state is not supported when TP is not enabled or when shard_optimizer_state is enabled"
            )
        optimizer_state_dict[
            "fp32_from_fp16"] = pp_merged_fp32_from_fp16_groups
        optimizer_state_dict["param_name_groups"] = param_name_groups
    return optimizer_state_dict
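A minimal usage sketch (an assumption, not part of the original script): gather the fp16 optimizer state and write one partial checkpoint per pipeline/tensor-parallel rank. The `args`, `model`, and `optimizer` names and the checkpoint filename are illustrative.

import torch

# Call on every rank: when shard_optimizer_state=True the function itself performs
# send/recv across the rdp group, so all rdp ranks must participate.
opt_state = save_fp16_optimizer(args, model, optimizer, partial=True)
if smp.rdp_rank() == 0:
    # Write one partial checkpoint per pp/tp rank (illustrative filename).
    torch.save(
        {"optimizer": opt_state},
        f"fp16_optimizer_pp{smp.pp_rank()}_tp{smp.tp_rank()}.pt",
    )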
Example #2
from typing import List

import torch
import smdistributed.modelparallel.torch as smp

# GPTPretrainingDataset, BertPretrainingDataset, and DummyDataset are assumed to be
# defined elsewhere in the same training script.
def create_pretraining_dataloader(
    input_paths: List[str],
    batch_size: int,
    max_sequence_length: int,
    seed: int,
    dp_rank: int,
    dp_size: int,
    shuffle: bool = False,
    zipped: bool = True,
    use_last_file_only: bool = False,
    data_type: str = "GPT",
):
    if smp.pp_rank() == 0:
        if data_type == "GPT":
            data = GPTPretrainingDataset(
                input_paths=input_paths, max_sequence_length=max_sequence_length, zipped=zipped, use_last_file_only=use_last_file_only
            )
        elif data_type == "BERT":
            if len(input_paths) > 1:
                print(f"BERT data only support single file when calling create_pretraining_dataloader, reading the first file instead..")
            data = BertPretrainingDataset(input_file=input_paths[0], max_pred_length=max_sequence_length)
        else:
            raise ValueError(f"Unsupported data type {data_type}")
        # TODO: call sampler.set_epoch() so data is reshuffled each epoch; otherwise the same order is used for every epoch.
        # Not relevant for now since we do not train for multiple epochs.
        sampler = torch.utils.data.DistributedSampler(
            data,
            shuffle=shuffle,
            seed=seed,
            rank=dp_rank,
            num_replicas=dp_size,
            drop_last=True,
        )
        dataloader = torch.utils.data.DataLoader(
            data,
            sampler=sampler,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
        smp.broadcast(len(dataloader), smp.PP_GROUP)
    else:
        data_len = smp.recv_from(0, smp.RankType.PP_RANK)
        dataset = DummyDataset(data_len * batch_size, data_type=data_type)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=True)

    return dataloader
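A hedged usage sketch of the GPT branch; the file paths, batch size, and sequence length are placeholder values, and the data-parallel rank/size come from the smp APIs already used above.

import glob

# Illustrative call: build a GPT pretraining dataloader for this data-parallel rank.
train_paths = sorted(glob.glob("/fsx/gpt_train/*.json"))  # hypothetical data location
train_dataloader = create_pretraining_dataloader(
    input_paths=train_paths,
    batch_size=8,
    max_sequence_length=1024,
    seed=12345,
    dp_rank=smp.dp_rank(),
    dp_size=smp.dp_size(),
    shuffle=True,
    zipped=True,
    data_type="GPT",
)
for batch in train_dataloader:
    ...  # forward/backward happens inside an smp.step-decorated function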
Example #3
def create_pretraining_dataloader(
    input_paths: List[str],
    batch_size: int,
    max_sequence_length: int,
    seed: int,
    dp_rank: int,
    dp_size: int,
    shuffle: bool = False,
    zipped: bool = True,
    use_last_file_only: bool = False,
    data_type: str = "openwebtext",
):
    if smp.pp_rank() == 0:
        if data_type == "openwebtext":
            data = OpenwebtextPretrainingDataset(
                input_paths=input_paths, max_sequence_length=max_sequence_length, zipped=zipped, use_last_file_only=use_last_file_only
            )
        elif data_type == "wiki":
            if len(input_paths) > 1:
                print(f"Wiki data only support single file when calling create_pretraining_dataloader, reading the first file instead..")
            data = WikiPretrainingDataset(input_file=input_paths[0], max_pred_length=max_sequence_length)
        else:
            raise ValueError(f"Unsupported data type {data_type}")
        sampler = torch.utils.data.DistributedSampler(
            data,
            shuffle=shuffle,
            seed=seed,
            rank=dp_rank,
            num_replicas=dp_size,
            drop_last=True,
        )
        dataloader = torch.utils.data.DataLoader(
            data,
            sampler=sampler,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
        smp.broadcast(len(dataloader), smp.PP_GROUP)
    else:
        data_len = smp.recv_from(0, smp.RankType.PP_RANK)
        dataset = DummyDataset(data_len * batch_size, data_type=data_type)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=True)

    return dataloader
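The same call pattern applies to this variant; a brief, hedged sketch of the "wiki" branch, where only the first input file is read. The path and sizes are placeholders.

# Illustrative call for the "wiki" branch: a single input file is expected.
val_dataloader = create_pretraining_dataloader(
    input_paths=["/fsx/wiki_val/validation.hdf5"],  # hypothetical path
    batch_size=4,
    max_sequence_length=512,
    seed=2023,
    dp_rank=smp.dp_rank(),
    dp_size=smp.dp_size(),
    shuffle=False,
    data_type="wiki",
)
# On ranks with smp.pp_rank() != 0 the helper returns a DataLoader over DummyDataset
# of the same length, so every pipeline rank can iterate in lockstep.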
Example #4
def get_tp_merged_fp32_from_fp16_param_groups(optimizer,
                                              cpu_fp32_from_fp16_groups):
    def _merge_param_group_tp_group(group_idx, param_group):
        result_fp32_from_fp16_param_group = []
        param_name_group = {}
        for i, param in enumerate(param_group):
            # for each param, obtain param_name from param using two dicts above for tp_rank 0
            param_index = param_id_to_index_tp_group[rank_0][
                fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i]]
            param_name = param_index_to_name_tp_group[rank_0][param_index]
            # obtain the distribution axis for the param and check whether it's distributed
            # axis = master_distribution_axis_tp_rank_0[fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i]]
            axis = master_distribution_axis_tp_rank_0.get(
                fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i],
                None)
            if axis is not None:
                tensors = []
                for r in range(smp.tp_size()):
                    # if distributed, for each rank, obtain param id from index using above two dicts
                    param_index_r = param_name_to_index_tp_group[r][param_name]
                    param_id_r = param_index_to_id_tp_group[r][param_index_r]

                    # search for the param id in fp32_from_fp16_paramid_groups_tp_group and find its index
                    group_param_idx = fp32_from_fp16_paramid_groups_tp_group[
                        r][group_idx].index(param_id_r)
                    # use the param corresponding to the index from fp32_from_fp16_groups for concatenation along axis
                    tensors.append(fp32_from_fp16_param_groups_tp_group[r]
                                   [group_idx][group_param_idx])
                result_fp32_from_fp16_param_group.append(
                    torch.cat(tensors, axis))
            else:
                # if not distributed, use the tp_rank 0 param as-is
                result_fp32_from_fp16_param_group.append(param)
            param_name_group[param_name] = i
        return result_fp32_from_fp16_param_group, param_name_group

    # get the param_index_to_name and param_name_to_index mappings across the tp_group
    param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group
    param_name_to_index_tp_group = smp_state.param_name_to_index_tp_group
    # get the param_id_to_index and param_index_to_id mappings across the tp_group
    param_id_to_index = optimizer._param_id_to_index()
    param_id_to_index_tp_group = smp.allgather(param_id_to_index, smp.TP_GROUP)
    param_index_to_id_tp_group = _get_param_index_to_id(
        param_id_to_index_tp_group)
    # allgather all param ids and all params for fp32_from_fp16_groups
    fp32_from_fp16_paramid_groups = optimizer.fp32_from_fp16_paramid_groups
    fp32_from_fp16_paramid_groups_tp_group = smp.allgather(
        fp32_from_fp16_paramid_groups, smp.TP_GROUP)
    fp32_from_fp16_param_groups_tp_group = smp.allgather(
        cpu_fp32_from_fp16_groups, smp.TP_GROUP)
    # broadcast distribution axis from tp_rank 0 to all tp_ranks
    master_distribution_axis_tp_rank_0 = None
    if smp.tp_rank() == 0:
        master_distribution_axis_tp_rank_0 = optimizer.master_distribution_axis
        smp.broadcast(master_distribution_axis_tp_rank_0, smp.TP_GROUP)
    else:
        master_distribution_axis_tp_rank_0 = smp.recv_from(
            0, smp.RankType.TP_RANK)

    result_fp32_from_fp16_param_groups = []
    param_name_groups = []
    rank_0 = 0
    # iterate through all the params for tp_group_fp32_from_fp16_groups[rank_0]
    for group_idx, param_group in enumerate(
            fp32_from_fp16_param_groups_tp_group[rank_0]):
        result_fp32_from_fp16_param_group, param_name_group = _merge_param_group_tp_group(
            group_idx, param_group)
        result_fp32_from_fp16_param_groups.append(
            result_fp32_from_fp16_param_group)
        param_name_groups.append(param_name_group)
    return result_fp32_from_fp16_param_groups, param_name_groups
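For context, a hedged sketch (assumed, not from the original script) of how the merged groups returned here can be inspected after merging, mirroring the full-save branch in Example #1.

# Must be called on all tp ranks, since the function allgathers over smp.TP_GROUP.
cpu_fp32_from_fp16_groups = [[p.cpu() for p in group]
                             for group in optimizer.fp32_from_fp16_groups]
merged_groups, param_name_groups = get_tp_merged_fp32_from_fp16_param_groups(
    optimizer, cpu_fp32_from_fp16_groups)
# param_name_groups[g] maps a parameter name to its index within merged_groups[g].
for group_idx, name_to_idx in enumerate(param_name_groups):
    for name, idx in name_to_idx.items():
        full_tensor = merged_groups[group_idx][idx]
        print(f"group {group_idx}: {name} -> merged shape {tuple(full_tensor.shape)}")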