def create_and_check_bloom_weight_initialization(self, config, *args): model = BloomModel(config) model_std = model.config.initializer_range / math.sqrt(2 * model.config.n_layer) for key in model.state_dict().keys(): if "c_proj" in key and "weight" in key: self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001) self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
def convert_bloom_checkpoint_to_pytorch(bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp): # Construct model if bloom_config_file == "": config = BloomConfig() else: config = BloomConfig.from_json_file(bloom_config_file) if shard_model: file_names = os.listdir(bloom_checkpoint_path) file_names = list( sorted( filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))) index_dict = {"weight_map": {}, "metadata": {}} total_size = 0 missing_keys = None config = BloomConfig() for j, file in enumerate(file_names): print("Processing file: {}".format(file)) tensors = None for i in range(pretraining_tp): # load all TP files f_name = file.replace("model_00", f"model_0{i}") temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu") # Rename keys in the transformers names keys = list(temp.keys()) for key in keys: temp[layer_name_mapping(key, file)] = temp.pop(key) if tensors is None: tensors = temp else: for key in tensors.keys(): if any( key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH): # We average (sum and then divide) some weights accross TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425) tensors[key] += temp[key] else: # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel cat_dim = 1 if any( text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0 # We concatenate these weights accross TP ranks tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim) # Divide by the number of TP the weights we want to average for key in tensors.keys(): if any( key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH): tensors[key] = tensors[key] / pretraining_tp torch.save( tensors, os.path.join( pytorch_dump_folder_path, "pytorch_model_{}-of-{}.bin".format( str(j + 1).zfill(5), str(len(file_names)).zfill(5)), ), ) for key in tensors.keys(): value = tensors[key] total_size += value.numel() * get_dtype_size(value.dtype) if key not in index_dict["weight_map"]: index_dict["weight_map"][ key] = "pytorch_model_{}-of-{}.bin".format( str(j + 1).zfill(5), str(len(file_names)).zfill(5)) config = BloomConfig() pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME index_dict["metadata"]["total_size"] = total_size with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: f.write(config.to_json_string()) with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f: json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n" f.write(json_config) else: model = BloomModel(config) file_names = os.listdir(bloom_checkpoint_path) file_names = list( sorted( filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))) missing_keys = None for i, file in enumerate(file_names): tensors = None for i in range(pretraining_tp): # load all TP files f_name = file.replace("model_00", f"model_0{i}") temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu") # Rename keys in the transformers names keys = list(temp.keys()) for key in keys: temp[layer_name_mapping(key, file)] = temp.pop(key) if tensors is None: tensors = temp else: for key in tensors.keys(): # We average (sum and then divide) some weights accross TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425) if any( key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH): tensors[key] += temp[key] else: # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel cat_dim = 1 if any( text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0 # We concatenate these weights accross TP ranks tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim) # Divide by the number of TP the weights we want to average for key in tensors.keys(): if any( key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH): tensors[key] = tensors[key] / pretraining_tp other_keys = model.load_state_dict(tensors, strict=False) assert not other_keys.unexpected_keys if missing_keys is None: missing_keys = set(other_keys.missing_keys) else: missing_keys = missing_keys.intersection( set(other_keys.missing_keys)) assert not missing_keys # Save pytorch-model os.makedirs(pytorch_dump_folder_path, exist_ok=True) pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME print( f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}" ) model = model.to(config.torch_dtype) torch.save(model.state_dict(), pytorch_weights_dump_path) print(f"Save configuration file to {pytorch_config_dump_path}") with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: f.write(config.to_json_string())