def split_partition(model, partitions, tp_size, write_path=None):
    """Split the parameters of a TP=1 model across ``tp_size`` tensor-parallel ranks.

    For every parameter of ``model`` (whose shapes are the per-rank target
    shapes) the matching source parameter in ``partitions[0]`` is either
    replicated (shapes already equal), split along the last dimension (first
    dimensions equal) or split along the first dimension (otherwise). The
    ranks are then visited from ``tp_size - 1`` down to ``0``; for each rank
    the corresponding slices are loaded into ``model`` and, if ``write_path``
    is given, ``model.save_to(write_path)`` is called.

    Args:
        model: Target model; must expose ``named_parameters()`` and, when
            ``write_path`` is set, ``save_to()``. Its parameters are
            overwritten in place.
        partitions: List containing exactly one list of source parameters,
            ordered like ``model.named_parameters()``.
        tp_size: Target tensor-parallel size (>= 1).
        write_path: Optional path passed to ``model.save_to`` once per rank.

    Raises:
        ValueError: If ``partitions`` does not hold exactly one partition or
            ``tp_size`` is smaller than 1.
        RuntimeError: If a slice differs from the required shape in a way
            that cannot be fixed by padding the first or the last dimension.
    """
    if len(partitions) != 1:
        raise ValueError(
            "Can only split partitions of model with TP=1. For partitions of models with TP>1, merge first."
        )

    if tp_size < 1:
        raise ValueError("TP size must to be >= 1.")

    app_state = AppState()
    app_state.data_parallel_rank = 0
    app_state.pipeline_model_parallel_size = 1  # not supported yet in this script
    app_state.tensor_model_parallel_size = tp_size
    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
    app_state.tensor_model_parallel_rank = tp_size - 1

    source_params = partitions[0]

    # First pass: decide, per parameter, how the source tensor maps onto ranks.
    splits = []
    for idx, (_, param) in enumerate(model.named_parameters()):
        source = source_params[idx]
        if param.shape == source.shape:
            # Same shape on every rank: the parameter is replicated.
            split = [source.data] * tp_size
        elif param.shape[0] == source.shape[0]:
            split = torch.split(source.data, param.shape[-1], dim=-1)
        else:
            split = torch.split(source.data, param.shape[0], dim=0)
        splits.append(split)

    # Second pass: ranks are handled in descending order, so rank 0 is loaded
    # (and saved) last.
    # NOTE(review): presumably the save connector finalizes the archive on
    # rank 0 -- confirm against the connector implementation.
    for rank in range(tp_size - 1, -1, -1):
        app_state.tensor_model_parallel_rank = rank

        for idx, (name, param) in enumerate(model.named_parameters()):
            split_val = splits[idx][rank]

            if param.shape != split_val.shape:
                logging.info(
                    f"Warning: Shape mismatch for parameter {name} required shape: {param.shape}, split shape: {split_val.shape}. Padding to match required size."
                )

                if split_val.shape[1:] == param.shape[1:]:
                    # F.pad lists padding from the last dimension backwards,
                    # so the final entry is the trailing pad of dimension 0.
                    pad = [0, 0] * len(split_val.shape)
                    pad[-1] = param.shape[0] - split_val.shape[0]
                    split_val = torch.nn.functional.pad(split_val, pad, 'constant')
                elif split_val.shape[:-1] == param.shape[:-1]:
                    pad = [0, param.shape[-1] - split_val.shape[-1]]
                    split_val = torch.nn.functional.pad(split_val, pad, 'constant')
                else:
                    raise RuntimeError(
                        f"Can not handle parameter {name}, required shape: {param.shape}, split shape: {split_val.shape}."
                    )
            else:
                # torch.split returns views of the source tensor (and the
                # replicated case reuses the source tensor itself). Clone so
                # the target parameter owns its storage: it no longer aliases
                # the source partition, and serializing it with torch.save
                # does not drag the full shared storage along.
                split_val = split_val.clone()

            param.data = split_val

        if write_path is not None:
            model.save_to(write_path)
def main():
    """Command-line entry point: change the tensor-parallel size of a .nemo model.

    Reads the source model with ``--tensor_model_parallel_size`` partitions.
    If that size is greater than 1 the partitions are first merged into a
    TP=1 model via ``merge_partition``; if ``--target_tensor_model_parallel_size``
    is greater than 1 the (merged or restored) TP=1 model is then split via
    ``split_partition``. The final step writes to ``--target_file``.

    NOTE(review): when both the source and the target TP size are 1, no call
    in this function writes ``--target_file``; the model is only restored.
    """
    parser = ArgumentParser()
    parser.add_argument("--model_file", type=str, required=True, help="Path to source .nemo file")
    parser.add_argument("--target_file", type=str, required=True, help="Path to write target .nemo file")
    parser.add_argument("--tensor_model_parallel_size", type=int, required=True, help="TP size of source model")
    parser.add_argument("--target_tensor_model_parallel_size", type=int, required=True, help="TP size of target model")
    parser.add_argument(
        "--model_class",
        type=str,
        default="nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel",
        help="NeMo model class. This script should support all NeMo megatron models that use Tensor Parallel",
    )
    parser.add_argument("--precision", default=16, help="PyTorch Lightning Trainer precision flag")

    args = parser.parse_args()

    # Command-line values arrive as strings; only "32" and "16" are turned
    # into integers, anything else is forwarded to the Trainer unchanged.
    precision = args.precision
    if args.precision in ["32", "16"]:
        precision = int(float(args.precision))

    tp_size = args.tensor_model_parallel_size
    tgt_tp_size = args.target_tensor_model_parallel_size
    cls = model_utils.import_class_by_path(args.model_class)
    cpu_device = torch.device("cpu")

    def _new_trainer():
        # Single-device CPU trainer used for restoring and re-instantiating.
        return Trainer(devices=1, plugins=NLPDDPPlugin(), accelerator="cpu", precision=precision)

    trainer = _new_trainer()
    app_state = AppState()

    def _reset_app_state():
        # Put the parallel state back to "source TP size, no pipeline parallelism".
        app_state.data_parallel_rank = 0
        app_state.pipeline_model_parallel_size = 1  # not supported yet in this script
        app_state.tensor_model_parallel_size = tp_size
        app_state.model_parallel_size = (
            app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
        )

    def _rebuild_model(source_model, new_tp_size):
        # Instantiate a fresh model from the source config with a new TP size.
        source_model.cfg.tensor_model_parallel_size = new_tp_size
        app_state.model_parallel_size = new_tp_size
        rebuilt = cls(source_model.cfg, _new_trainer()).to('cpu')
        rebuilt._save_restore_connector = NLPSaveRestoreConnector()
        return rebuilt

    _reset_app_state()

    if tp_size > 1:
        partitions = []
        for rank in range(tp_size):
            app_state.tensor_model_parallel_rank = rank
            model = cls.restore_from(restore_path=args.model_file, trainer=trainer, map_location=cpu_device)
            partitions.append([p for _, p in model.named_parameters()])
            # app_state is being updated incorrectly during restore
            _reset_app_state()

        model = _rebuild_model(model, 1)

        if tgt_tp_size > 1:
            # The merged model is only an intermediate; it is split below.
            merge_partition(model, partitions)
        else:
            merge_partition(model, partitions, args.target_file)
    else:
        app_state.model_parallel_size = 1
        # Restore onto CPU, consistent with the TP>1 path and the CPU trainer.
        model = cls.restore_from(restore_path=args.model_file, trainer=trainer, map_location=cpu_device)

    if tgt_tp_size > 1:
        partitions = [[p for _, p in model.named_parameters()]]
        model = _rebuild_model(model, tgt_tp_size)
        split_partition(model, partitions, tgt_tp_size, args.target_file)

    logging.info("Successfully finished changing partitions!")