def create_and_check_bloom_weight_initialization(self, config, *args):
    """Check the scaled initialization of ``c_proj`` weight tensors.

    Every state-dict entry whose name contains both ``"c_proj"`` and ``"weight"`` must
    have a standard deviation within 1e-3 of
    ``initializer_range / sqrt(2 * n_layer)`` and a mean within 1e-2 of zero.

    NOTE(review): confirm BloomModel actually has parameters whose names contain
    "c_proj"; if none match, the loop below checks nothing.
    """
    model = BloomModel(config)
    expected_std = model.config.initializer_range / math.sqrt(2 * model.config.n_layer)

    # Build the state dict once instead of re-creating it for every lookup.
    state = model.state_dict()
    for name, weight in state.items():
        if "weight" not in name or "c_proj" not in name:
            continue
        self.parent.assertLessEqual(abs(torch.std(weight) - expected_std), 0.001)
        self.parent.assertLessEqual(abs(torch.mean(weight)), 0.01)
    def create_and_check_bloom_model_past_large_inputs(self, config, input_ids, input_mask, *args):
        """Check that cached decoding of three extra tokens matches an uncached pass.

        A first pass over ``input_ids`` (with ``use_cache=True``) yields the cache.
        Three random tokens, with a random mask drawn with ``vocab_size=2``, are then
        run twice: once on top of the cache, and once as part of the full concatenated
        sequence. One randomly chosen index along the last output dimension must agree
        over the three new positions within ``atol=1e-3``.
        """
        model = BloomModel(config=config)
        model.to(torch_device)
        model.eval()

        # Prime the cache with the original prompt; only the cache is needed.
        _, cache = model(input_ids, attention_mask=input_mask, use_cache=True).to_tuple()

        # Three extra tokens and their mask entries.
        extra_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        extra_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # Full sequence = prompt + extra tokens (and the same for the mask).
        full_ids = torch.cat([input_ids, extra_tokens], dim=-1)
        full_mask = torch.cat([input_mask, extra_mask], dim=-1)

        hidden_full = model(full_ids, attention_mask=full_mask)["last_hidden_state"]
        hidden_cached = model(extra_tokens, attention_mask=full_mask, past_key_values=cache)["last_hidden_state"]
        self.parent.assertTrue(hidden_cached.shape[1] == extra_tokens.shape[1])

        # Compare a single random index of the last dimension over the new positions.
        channel = ids_tensor((1,), hidden_cached.shape[-1]).item()
        expected = hidden_full[:, -3:, channel].detach()
        actual = hidden_cached[:, :, channel].detach()

        self.parent.assertTrue(torch.allclose(actual, expected, atol=1e-3))
    def create_and_check_bloom_model_past(self, config, input_ids, input_mask, *args):
        """Check cache bookkeeping and single-token cached decoding.

        The default call must return as many entries as an explicit
        ``use_cache=True`` call. A ``use_cache=False`` call must return exactly one
        entry fewer. Decoding one random token on top of the cache must then match the
        last position of an uncached pass over the extended sequence (``atol=1e-3``)
        on a randomly chosen index of the last dimension.
        """
        model = BloomModel(config=config)

        model.to(torch_device)
        model.eval()

        # Three forward passes over the same prompt, each with an all-ones mask.
        outputs_cached = model(input_ids, attention_mask=torch.ones_like(input_ids), use_cache=True)
        outputs_default = model(input_ids, attention_mask=torch.ones_like(input_ids))
        outputs_uncached = model(input_ids, use_cache=False, attention_mask=torch.ones_like(input_ids))

        self.parent.assertTrue(len(outputs_cached) == len(outputs_default))
        self.parent.assertTrue(len(outputs_cached) == len(outputs_uncached) + 1)

        cache = outputs_cached["past_key_values"]

        # One extra random token, appended to the prompt for the uncached pass.
        new_token = ids_tensor((self.batch_size, 1), config.vocab_size)
        full_ids = torch.cat([input_ids, new_token], dim=-1)

        hidden_full = model(full_ids)["last_hidden_state"]
        hidden_cached = model(new_token, past_key_values=cache)["last_hidden_state"]

        channel = ids_tensor((1,), hidden_cached.shape[-1]).item()
        expected = hidden_full[:, -1, channel].detach()
        actual = hidden_cached[:, 0, channel].detach()

        self.parent.assertTrue(torch.allclose(actual, expected, atol=1e-3))
    def create_and_check_bloom_model_attention_mask_past(self, config, input_ids, input_mask, *args):
        """Check cached decoding when the second half of the prompt is masked out.

        After the first (cache-producing) pass, one token inside the masked tail of
        ``input_ids`` is overwritten with a random id. The test expects the cached and
        uncached hidden states of a newly appended token to still agree (``atol=1e-3``)
        on a randomly chosen index of the last dimension.

        NOTE: ``input_ids`` is modified in place (one column is overwritten).
        """
        model = BloomModel(config=config)
        model.to(torch_device)
        model.eval()

        # Attend only to the first half of the prompt.
        half = self.seq_length // 2
        mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
        mask[:, half:] = 0

        _, cache = model(input_ids, attention_mask=mask).to_tuple()

        new_token = ids_tensor((self.batch_size, 1), config.vocab_size)

        # Replace one token counted 1..half positions from the end, i.e. inside the
        # masked tail (assumes ids_tensor draws from [0, half) — TODO confirm).
        offset = ids_tensor((1,), half).item() + 1
        replacement = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -offset] = replacement

        # Extend ids and mask by one (unmasked) position for the new token.
        full_ids = torch.cat([input_ids, new_token], dim=-1)
        ones_column = torch.ones((mask.shape[0], 1), dtype=torch.long, device=torch_device)
        mask = torch.cat([mask, ones_column], dim=1)

        hidden_full = model(full_ids, attention_mask=mask)["last_hidden_state"]
        hidden_cached = model(new_token, past_key_values=cache, attention_mask=mask)["last_hidden_state"]

        channel = ids_tensor((1,), hidden_cached.shape[-1]).item()
        expected = hidden_full[:, -1, channel].detach()
        actual = hidden_cached[:, 0, channel].detach()

        self.parent.assertTrue(torch.allclose(actual, expected, atol=1e-3))
    def create_and_check_bloom_model(self, config, input_ids, input_mask, *args):
        """Run a plain forward pass and check the hidden-state shape and cache length.

        ``input_mask`` is accepted for signature compatibility but not used.
        """
        model = BloomModel(config=config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids)
        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)

        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)
        # One cache entry per layer.
        self.parent.assertEqual(len(result.past_key_values), config.n_layer)
# Example #6
# 0
def _list_tp_rank0_files(bloom_checkpoint_path):
    """Return the sorted names of the layer files written by tensor-parallel rank 0.

    Only names that start with ``"layer"`` and contain ``"model_00"`` are kept. The
    files of the other ranks are derived from these names in ``_merge_tp_shards``.

    Raises:
        FileNotFoundError: if the directory contains no such file.
    """
    file_names = sorted(
        name
        for name in os.listdir(bloom_checkpoint_path)
        if name.startswith("layer") and "model_00" in name
    )
    if not file_names:
        raise FileNotFoundError(
            f"No tensor-parallel rank-0 layer files ('layer*model_00*') found in {bloom_checkpoint_path!r}"
        )
    return file_names


def _merge_tp_shards(bloom_checkpoint_path, file, pretraining_tp):
    """Load every tensor-parallel (TP) rank of one layer file and merge them.

    Args:
        bloom_checkpoint_path: directory holding the layer files.
        file: name of the rank-0 file (contains ``"model_00"``).
        pretraining_tp: number of TP ranks. Rank ``r`` is read from ``file`` with
            ``"model_00"`` replaced by ``f"model_{r:02d}"``.

    Returns:
        A state dict with keys renamed via ``layer_name_mapping``:
        - keys ending with an entry of ``WEIGHTS_TO_AVERAGE_ENDSWITH`` are averaged
          over the ranks;
        - keys containing an entry of ``WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN`` are
          concatenated along dim 1;
        - all other keys are concatenated along dim 0.
    """
    tensors = None
    for rank in range(pretraining_tp):
        # Zero-pad to two digits to match "model_00"; f"model_0{rank}" would produce
        # "model_010" rather than "model_10" for rank 10.
        f_name = file.replace("model_00", f"model_{rank:02d}")
        temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")

        # Rename keys (in place) to the transformers names.
        for key in list(temp.keys()):
            temp[layer_name_mapping(key, file)] = temp.pop(key)

        if tensors is None:
            tensors = temp
            continue

        for key in tensors.keys():
            if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                # We average (sum and then divide) some weights accross TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                tensors[key] += temp[key]
            else:
                # Some weights are RowParallelLinear in Megatron-Deepspeed (concatenated
                # along dim 1), others are ColumnParallel (concatenated along dim 0).
                cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

    # Turn the accumulated sums of the averaged weights into means.
    for key in tensors.keys():
        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
            tensors[key] = tensors[key] / pretraining_tp
    return tensors


def convert_bloom_checkpoint_to_pytorch(bloom_checkpoint_path,
                                        bloom_config_file,
                                        pytorch_dump_folder_path, shard_model,
                                        pretraining_tp):
    """Convert a tensor-parallel BLOOM checkpoint into transformers format.

    Args:
        bloom_checkpoint_path: directory with the ``layer*model_XX*`` files.
        bloom_config_file: path to a BloomConfig JSON file. An empty string (or
            ``None``) selects the default ``BloomConfig()``.
        pytorch_dump_folder_path: output directory; created if missing.
        shard_model: if true, write one ``pytorch_model_XXXXX-of-YYYYY.bin`` per
            layer file plus ``WEIGHTS_NAME + ".index.json"``. Otherwise, load all
            weights into a ``BloomModel``, cast it to ``config.torch_dtype`` and save
            a single ``WEIGHTS_NAME`` file. Both modes also write ``CONFIG_NAME``.
        pretraining_tp: tensor-parallel degree of the checkpoint (must be >= 1).

    Raises:
        ValueError: if ``pretraining_tp < 1``. In non-sharded mode, also if a file
            holds keys the model does not expect, or some model keys are provided by
            no file.
        FileNotFoundError: if no rank-0 layer file is found.
    """
    if pretraining_tp < 1:
        raise ValueError(f"pretraining_tp must be >= 1, got {pretraining_tp}")

    # Fail early rather than writing an empty index or a randomly initialized model.
    file_names = _list_tp_rank0_files(bloom_checkpoint_path)

    # The same config is used (and saved) in both modes, so config.json always
    # reflects bloom_config_file when one is given.
    config = BloomConfig.from_json_file(bloom_config_file) if bloom_config_file else BloomConfig()

    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)

    if shard_model:
        num_shards = len(file_names)
        index_dict = {"weight_map": {}, "metadata": {}}
        total_size = 0

        for shard_idx, file in enumerate(file_names, start=1):
            print(f"Processing file: {file}")
            tensors = _merge_tp_shards(bloom_checkpoint_path, file, pretraining_tp)

            shard_name = f"pytorch_model_{shard_idx:05d}-of-{num_shards:05d}.bin"
            torch.save(tensors, os.path.join(pytorch_dump_folder_path, shard_name))

            for key, value in tensors.items():
                total_size += value.numel() * get_dtype_size(value.dtype)
                # A key present in several files stays mapped to the first shard.
                index_dict["weight_map"].setdefault(key, shard_name)

        index_dict["metadata"]["total_size"] = total_size
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
        index_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json")
        with open(index_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(index_dict, indent=2, sort_keys=True) + "\n")
        return

    model = BloomModel(config)
    missing_keys = None
    for file in file_names:
        tensors = _merge_tp_shards(bloom_checkpoint_path, file, pretraining_tp)
        other_keys = model.load_state_dict(tensors, strict=False)
        if other_keys.unexpected_keys:
            raise ValueError(f"{file} contains keys unknown to BloomModel: {other_keys.unexpected_keys}")
        # A key counts as missing only if no file provides it.
        file_missing = set(other_keys.missing_keys)
        missing_keys = file_missing if missing_keys is None else missing_keys & file_missing

    if missing_keys:
        raise ValueError(f"No checkpoint file provides these keys: {sorted(missing_keys)}")

    # Save pytorch-model
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
    print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
    model = model.to(config.torch_dtype)
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print(f"Save configuration file to {pytorch_config_dump_path}")
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())