Example #1
import torch
from collections import OrderedDict
# NOTE: assumed to be torch's DDP wrapper here; the original patch may use
# Megatron's own DistributedDataParallel class instead.
from torch.nn.parallel import DistributedDataParallel


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert-parallel (MoE) parameters."""
    # TODO: update patch
    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last

    expert_dp_comm = "none"

    if mpu.get_data_parallel_rank() == 0:
        # on data-parallel rank 0, fall back to Megatron's native save_checkpoint
        from megatron.checkpointing import save_checkpoint as save_checkpoint_native

        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
        return

    args = get_args()

    # Unwrap the DDP wrapper before collecting the state dict.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    print_rank_last("saving checkpoint at iteration {:7d} to {}".format(
        iteration, args.save))

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict["model"] = model.state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0))

    def extract_expert_param(state_dict, expert_dp_comm="none"):
        # Recursively keep only expert parameters, i.e. tensors tagged with a
        # dp_comm attribute equal to expert_dp_comm.
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                if len(v_new) > 0:
                    state_dict_new[k] = v_new
            elif hasattr(v, "dp_comm") and v.dp_comm == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    state_dict["model"] = extract_expert_param(state_dict["model"],
                                               expert_dp_comm)

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
            param_global_idx = 0
            for param_group in optimizer.optimizer.param_groups:
                for param in param_group["params"]:
                    if not (hasattr(param, "dp_comm")
                            and param.dp_comm == expert_dp_comm):
                        # this parameter is not an expert parameter, so its
                        # optimizer state does not need to be saved on this
                        # rank; data-parallel rank 0 already saves it
                        if args.fp16:
                            # fp16 optimizer may have empty state due to overflow
                            state_dict["optimizer"]["optimizer"]["state"].pop(
                                param_global_idx, None)
                        else:
                            state_dict["optimizer"]["state"].pop(
                                param_global_idx)
                    param_global_idx += 1
            if args.fp16:
                state_dict["optimizer"]["optimizer"].pop("param_groups")
                # fp32_from_fp16_params in state_dict is not a copy
                # but a reference to optimizer.fp32_from_fp16_params,
                # changing it in state_dict will change
                # optimizer.fp32_from_fp16_params as well
                # thus we create an empty fp32_from_fp16_params in state_dict
                # and only insert expert parameters.
                fp32_from_fp16_params = state_dict["optimizer"][
                    "fp32_from_fp16_params"]
                state_dict["optimizer"]["fp32_from_fp16_params"] = []
                for param_group in fp32_from_fp16_params:
                    param_group_copy = []
                    for param in param_group:
                        param_copy = (param if hasattr(param, "dp_comm")
                                      and param.dp_comm == expert_dp_comm else
                                      None)
                        param_group_copy.append(param_copy)
                    state_dict["optimizer"]["fp32_from_fp16_params"].append(
                        param_group_copy)
            else:
                state_dict["optimizer"].pop("param_groups")

    # Save. get_fmoe_checkpoint_name is defined elsewhere in the same
    # checkpointing patch (not shown in this example).
    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename

    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            "  successfully saved checkpoint at iteration {:7d} to {}".format(
                iteration, args.save),
            flush=True,
        )
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
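
The expert-parameter filtering above hinges on each expert parameter carrying a custom dp_comm attribute. Below is a minimal, self-contained sketch of how such tagging could be done and how the same filtering rule applies to a plain state_dict; the helper mark_expert_params and the toy model are illustrative assumptions, not the actual patch API.

import torch
from collections import OrderedDict


def mark_expert_params(module, comm="none"):
    # Hypothetical helper: tag every parameter of an expert module so that a
    # filter like extract_expert_param() above keeps it. dp_comm == "none"
    # means the parameter is not replicated across data-parallel ranks and
    # therefore must be saved by every rank.
    for p in module.parameters():
        setattr(p, "dp_comm", comm)


# Toy model: a shared projection plus a per-rank expert layer.
model = torch.nn.ModuleDict({
    "shared": torch.nn.Linear(16, 16),
    "expert": torch.nn.Linear(16, 16),
})
mark_expert_params(model["expert"])

# keep_vars=True returns the live Parameter objects, so the dp_comm attribute
# survives into the state_dict (same reason the code above passes keep_vars
# on non-zero data-parallel ranks).
state = model.state_dict(keep_vars=True)
expert_only = OrderedDict(
    (k, v.detach()) for k, v in state.items()
    if hasattr(v, "dp_comm") and v.dp_comm == "none"
)
print(list(expert_only.keys()))  # only the expert.* tensors remain
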
Example #2
import os

import torch

from megatron import mpu
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name

# NOTE: _parse_args, get_mp_merge_args, rebuild_tokenizer, get_model,
# get_parallel_checkpoint_name and merge_partitions are helpers defined
# elsewhere in the merge script this example is taken from.


def main():

    # Args
    args = _parse_args(extra_args_provider=get_mp_merge_args)
    model_type = args.model_type
    orig_model_parallel_size = args.model_parallel_size
    args.model_parallel_size = 1
    tokenizer = rebuild_tokenizer(args)

    print('\n merging model parallel partitions ...')
    print(' > number of partitions: {}'.format(orig_model_parallel_size))
    print(' > checkpoint path: {}'.format(args.load))
    print(' > model parameters:')
    print('    number of tokens ................ {} '.format(
        tokenizer.vocab_size))
    print('    number of layers ................ {}'.format(args.num_layers))
    print('    hidden size ..................... {}'.format(args.hidden_size))
    print('    number of attention heads ....... {}'.format(
        args.num_attention_heads))
    print('    maximum position embeddings ..... {}'.format(
        args.max_position_embeddings))

    # Full model.
    print('> building the full model ...')
    mpu.initialize.set_model_parallel_world_size(1)
    mpu.initialize.set_model_parallel_rank(0)
    merged_model = get_model(model_type)

    # Build and load partitions.
    partitions = []
    iteration = 0
    args.model_parallel_size = orig_model_parallel_size
    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
    for rank in range(args.model_parallel_size):
        mpu.initialize.set_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        print('> loading {} ...'.format(checkpoint_name))
        model_ = get_model(model_type)
        sd = torch.load(checkpoint_name, map_location='cpu')
        model_.load_state_dict(sd['model'])
        partitions.append(model_)

    # Parameter generators so we can loop through them simultaneously.
    merged_params_gen = merged_model.named_parameters()
    partitions_params_gen = [
        partition.named_parameters() for partition in partitions
    ]
    while True:
        try:

            # Get the params and check names.
            name, merged_param = next(merged_params_gen)
            print(' > working on {} ...'.format(name))
            print('     merged         type: {}, size: {}'.format(
                merged_param.dtype, list(merged_param.size())))
            partitions_param = []
            for rank, partition_params_gen in enumerate(partitions_params_gen):
                partition_name, partition_param = next(partition_params_gen)
                assert partition_name == name
                partitions_param.append(partition_param)
                print('     partition {}    type: {}, size: {}'.format(
                    rank, partition_param.dtype, list(partition_param.size())))

            # For the non-parallel parameters, simply copy the rank 0 values.
            if not hasattr(merged_param, 'model_parallel'):
                print('     non-parallel parameter, simple copy from rank 0')
                with torch.no_grad():
                    merged_param.data.copy_(partitions_param[0].data)
            # For parallel parameters, merge the values
            else:
                print('     parallel parameter merge with stride {} along '
                      'dimension {}'.format(merged_param.stride,
                                            merged_param.partition_dim))
                merge_partitions(merged_param, partitions_param,
                                 merged_param.partition_dim,
                                 merged_param.stride)

        except StopIteration:
            break

    # Save the model.
    args.model_parallel_size = 1
    mpu.initialize.set_model_parallel_rank(0)
    sd = {}
    sd['model'] = merged_model.state_dict_for_save_checkpoint()
    sd['iteration'] = iteration
    merged_path = os.path.join(args.load, 'merged')
    checkpoint_name = get_checkpoint_name(merged_path, iteration)
    ensure_directory_exists(checkpoint_name)
    print('> saving merged model to {}'.format(checkpoint_name))
    torch.save(sd, checkpoint_name)

    print('done :-)')
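
merge_partitions itself is not shown in this example; for the common stride-1 case it amounts to concatenating the per-rank shards along the dimension that was split. A minimal sketch under that assumption (merge_along_dim is a hypothetical name, not the script's actual helper):

import torch


def merge_along_dim(partitions, partition_dim):
    # Minimal sketch for stride == 1: the merged weight is simply the
    # per-rank shards concatenated along the partitioned dimension.
    return torch.cat([p.data for p in partitions], dim=partition_dim)


# Example: an 8x4 column-parallel weight split over 2 ranks into two 4x4 shards.
shards = [torch.randn(4, 4), torch.randn(4, 4)]
merged = merge_along_dim(shards, partition_dim=0)
print(merged.shape)  # torch.Size([8, 4])
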