Example #1
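Both examples assume the imports below. The module paths are an assumption based on a NeMo 1.x layout (NLPDDPPlugin and NLPSaveRestoreConnector living in nemo.collections.nlp.parts.nlp_overrides); adjust them if your NeMo version organizes these modules differently.

import torch
from argparse import ArgumentParser

from pytorch_lightning import Trainer

# Assumed NeMo 1.x import paths; verify against your installed version.
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin, NLPSaveRestoreConnector
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState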
def split_partition(model, partitions, tp_size, write_path=None):
    """Split a single TP=1 partition into `tp_size` tensor-parallel shards.

    `partitions` must contain exactly one entry: the parameter list of the TP=1
    model. Each shard is loaded into `model` one rank at a time and, if
    `write_path` is given, the model is saved after each rank's shards are loaded.
    """
    if len(partitions) != 1:
        raise ValueError(
            "Can only split partitions of model with TP=1. For partitions of models with TP>1, merge first."
        )

    if tp_size < 1:
        raise ValueError("TP size must to be >= 1.")

    app_state = AppState()
    app_state.data_parallel_rank = 0
    app_state.pipeline_model_parallel_size = 1  # not supported yet in this script
    app_state.tensor_model_parallel_size = tp_size
    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size

    app_state.tensor_model_parallel_rank = tp_size - 1

    idx = 0
    splits = []
    for _, param in model.named_parameters():
        if param.shape == partitions[0][idx].shape:
            # Same shape on every rank (e.g. layernorm weights): replicate the full tensor.
            split = [partitions[0][idx].data] * tp_size
        elif param.shape[0] == partitions[0][idx].shape[0]:
            # First dimension is unchanged in the target shape: shard along the last dimension.
            split = torch.split(partitions[0][idx].data, param.shape[-1], dim=-1)
        else:
            # Otherwise shard along the first dimension.
            split = torch.split(partitions[0][idx].data, param.shape[0], dim=0)
        splits.append(split)
        idx += 1

    # Load shards rank by rank (in reverse order, so rank 0's shard is the last one left in `model`).
    for i in range(tp_size - 1, -1, -1):
        app_state.tensor_model_parallel_rank = i

        idx = 0
        for name, param in model.named_parameters():
            split_val = splits[idx][i]

            if param.shape != split_val.shape:
                logging.warning(
                    f"Shape mismatch for parameter {name}: required shape {param.shape}, split shape {split_val.shape}. Padding to match required size."
                )

                # F.pad takes (left, right) padding pairs starting from the last dimension;
                # the branches below zero-pad either the end of dim 0 or the end of the last dim.
                if split_val.shape[1:] == param.shape[1:]:
                    # Only the first dimension is short: pad dim 0 up to the required size.
                    pad = [0, 0] * len(split_val.shape)
                    pad[-1] = param.shape[0] - split_val.shape[0]
                    split_val = torch.nn.functional.pad(split_val, pad, 'constant')
                elif split_val.shape[:-1] == param.shape[:-1]:
                    # Only the last dimension is short: pad it up to the required size.
                    pad = [0, param.shape[-1] - split_val.shape[-1]]
                    split_val = torch.nn.functional.pad(split_val, pad, 'constant')
                else:
                    raise RuntimeError(
                        f"Cannot handle parameter {name}: required shape {param.shape}, split shape {split_val.shape}."
                    )

            param.data = split_val
            idx += 1

        if write_path is not None:
            model.save_to(write_path)
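The branch selection above encodes how the script decides which dimension a weight was sharded on: if the target per-rank shape keeps the same first dimension as the full TP=1 tensor, the full tensor is split along its last dimension; otherwise it is split along dimension 0. A minimal toy sketch of that rule (the tensor and shapes here are purely illustrative):

import torch

tp_size = 2
full_weight = torch.arange(24.0).reshape(4, 6)  # stand-in for a TP=1 parameter

# Target per-rank shape (4, 3): first dim unchanged, so split along the last dim.
last_dim_shards = torch.split(full_weight, 3, dim=-1)
assert len(last_dim_shards) == tp_size and all(s.shape == (4, 3) for s in last_dim_shards)

# Target per-rank shape (2, 6): first dim changes, so split along dim 0.
dim0_shards = torch.split(full_weight, 2, dim=0)
assert len(dim0_shards) == tp_size and all(s.shape == (2, 6) for s in dim0_shards)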
Example #2
def main():
    parser = ArgumentParser()
    parser.add_argument("--model_file",
                        type=str,
                        required=True,
                        help="Path to source .nemo file")
    parser.add_argument("--target_file",
                        type=str,
                        required=True,
                        help="Path to write target .nemo file")
    parser.add_argument("--tensor_model_parallel_size",
                        type=int,
                        required=True,
                        help="TP size of source model")
    parser.add_argument("--target_tensor_model_parallel_size",
                        type=int,
                        required=True,
                        help="TP size of target model")
    parser.add_argument(
        "--model_class",
        type=str,
        default="nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel",
        help="NeMo model class. This script should support all NeMo Megatron models that use tensor parallelism.",
    )
    parser.add_argument("--precision",
                        default=16,
                        help="PyTorch Lightning Trainer precision flag")

    args = parser.parse_args()

    precision = args.precision
    if args.precision in ["32", "16"]:
        # --precision arrives as a string from the command line; convert plain "32"/"16"
        # to int for the Trainer and leave values such as "bf16" untouched.
        precision = int(float(args.precision))
    tp_size = args.tensor_model_parallel_size
    tgt_tp_size = args.target_tensor_model_parallel_size
    cls = model_utils.import_class_by_path(args.model_class)

    trainer = Trainer(devices=1,
                      plugins=NLPDDPPlugin(),
                      accelerator="cpu",
                      precision=precision)
    app_state = AppState()
    app_state.data_parallel_rank = 0
    app_state.pipeline_model_parallel_size = 1  # not supported yet in this script
    app_state.tensor_model_parallel_size = tp_size
    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size

    # Step 1: if the source checkpoint is tensor-parallel, restore every TP rank
    # and merge the shards into a single TP=1 model.
    if tp_size > 1:
        partitions = []
        for i in range(tp_size):
            app_state.tensor_model_parallel_rank = i
            model = cls.restore_from(restore_path=args.model_file,
                                     trainer=trainer,
                                     map_location=torch.device("cpu"))
            params = [p for _, p in model.named_parameters()]
            partitions.append(params)
            # app_state is being updated incorrectly during restore
            app_state.data_parallel_rank = 0
            app_state.pipeline_model_parallel_size = 1  # not supported yet in this script
            app_state.tensor_model_parallel_size = tp_size
            app_state.model_parallel_size = (
                app_state.pipeline_model_parallel_size *
                app_state.tensor_model_parallel_size)

        # Build a fresh TP=1 instance of the model to hold the merged weights.
        model.cfg.tensor_model_parallel_size = 1
        app_state.model_parallel_size = 1
        trainer = Trainer(devices=1,
                          plugins=NLPDDPPlugin(),
                          accelerator="cpu",
                          precision=precision)
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()

        if tgt_tp_size > 1:
            # Splitting still to come: keep the merged model in memory for now.
            merge_partition(model, partitions)
        else:
            # Target is TP=1: the merged model is the final result, write it out.
            merge_partition(model, partitions, args.target_file)
    else:
        app_state.model_parallel_size = 1
        model = cls.restore_from(restore_path=args.model_file, trainer=trainer)

    # Step 2: if the target TP size is greater than 1, split the (now TP=1) model
    # into the requested number of tensor-parallel shards and write them out.
    if tgt_tp_size > 1:
        partitions = []
        params = [p for _, p in model.named_parameters()]
        partitions.append(params)

        model.cfg.tensor_model_parallel_size = tgt_tp_size
        app_state.model_parallel_size = tgt_tp_size
        trainer = Trainer(devices=1,
                          plugins=NLPDDPPlugin(),
                          accelerator="cpu",
                          precision=precision)
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()

        split_partition(model, partitions, tgt_tp_size, args.target_file)

    logging.info("Successfully finished changing partitions!")