def get_config(self):
    return LukeConfig(
        vocab_size=self.vocab_size,
        entity_vocab_size=self.entity_vocab_size,
        entity_emb_size=self.entity_emb_size,
        hidden_size=self.hidden_size,
        num_hidden_layers=self.num_hidden_layers,
        num_attention_heads=self.num_attention_heads,
        intermediate_size=self.intermediate_size,
        hidden_act=self.hidden_act,
        hidden_dropout_prob=self.hidden_dropout_prob,
        attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        type_vocab_size=self.type_vocab_size,
        is_decoder=False,
        initializer_range=self.initializer_range,
        use_entity_aware_attention=self.use_entity_aware_attention,
    )
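For context, a config built this way is normally fed straight into the model under test. Below is a minimal sketch of such a consumer, assuming the usual HF tester helpers (ids_tensor, self.batch_size, self.seq_length) and a torch import are in scope; the method name is illustrative and not part of the original test.

def create_and_check_model_sketch(self):
    # Hedged sketch: build the config above and run a forward pass with word inputs only.
    config = self.get_config()
    model = LukeModel(config=config).eval()

    input_ids = ids_tensor([self.batch_size, self.seq_length], config.vocab_size)

    with torch.no_grad():
        outputs = model(input_ids=input_ids)

    # Without entity inputs, only the word hidden states are populated.
    assert outputs.last_hidden_state.shape == (self.batch_size, self.seq_length, config.hidden_size)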
Example #2
def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path,
                            pytorch_dump_folder_path, model_size):
    # Load configuration defined in the metadata file
    with open(metadata_path) as metadata_file:
        metadata = json.load(metadata_file)
    config = LukeConfig(use_entity_aware_attention=True,
                        **metadata["model_config"])

    # Load in the weights from the checkpoint_path
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Load the entity vocab file
    entity_vocab = load_entity_vocab(entity_vocab_path)

    tokenizer = RobertaTokenizer.from_pretrained(
        metadata["model_config"]["bert_model_name"])

    # Add special tokens to the token vocabulary for downstream tasks
    entity_token_1 = AddedToken("<ent>", lstrip=False, rstrip=False)
    entity_token_2 = AddedToken("<ent2>", lstrip=False, rstrip=False)
    tokenizer.add_special_tokens(
        dict(additional_special_tokens=[entity_token_1, entity_token_2]))
    config.vocab_size += 2

    print(f"Saving tokenizer to {pytorch_dump_folder_path}")
    tokenizer.save_pretrained(pytorch_dump_folder_path)
    with open(
            os.path.join(pytorch_dump_folder_path,
                         LukeTokenizer.vocab_files_names["entity_vocab_file"]),
            "w") as f:
        json.dump(entity_vocab, f)

    tokenizer = LukeTokenizer.from_pretrained(pytorch_dump_folder_path)

    # Initialize the embeddings of the special tokens
    word_emb = state_dict["embeddings.word_embeddings.weight"]
    ent_emb = word_emb[tokenizer.convert_tokens_to_ids(["@"])[0]].unsqueeze(0)
    ent2_emb = word_emb[tokenizer.convert_tokens_to_ids(["#"])[0]].unsqueeze(0)
    state_dict["embeddings.word_embeddings.weight"] = torch.cat(
        [word_emb, ent_emb, ent2_emb])

    # Initialize the query layers of the entity-aware self-attention mechanism
    for layer_index in range(config.num_hidden_layers):
        for matrix_name in ["query.weight", "query.bias"]:
            prefix = f"encoder.layer.{layer_index}.attention.self."
            state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix +
                                                                   matrix_name]
            state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix +
                                                                   matrix_name]
            state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix +
                                                                   matrix_name]

    # Initialize the embedding of the [MASK2] entity using that of the [MASK] entity for downstream tasks
    entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"]
    entity_emb[entity_vocab["[MASK2]"]] = entity_emb[entity_vocab["[MASK]"]]

    model = LukeModel(config=config).eval()

    missing_keys, unexpected_keys = model.load_state_dict(state_dict,
                                                          strict=False)
    if not (len(missing_keys) == 1
            and missing_keys[0] == "embeddings.position_ids"):
        raise ValueError(
            f"Missing keys {', '.join(missing_keys)}. Expected only missing embeddings.position_ids"
        )
    if not (all(
            key.startswith("entity_predictions") or key.startswith("lm_head")
            for key in unexpected_keys)):
        raise ValueError(
            f"Unexpected keys {', '.join([key for key in unexpected_keys if not (key.startswith('entity_predictions') or key.startswith('lm_head'))])}"
        )

    # Check outputs
    tokenizer = LukeTokenizer.from_pretrained(pytorch_dump_folder_path,
                                              task="entity_classification")

    text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
    span = (39, 42)
    encoding = tokenizer(text,
                         entity_spans=[span],
                         add_prefix_space=True,
                         return_tensors="pt")

    outputs = model(**encoding)

    # Verify word hidden states
    if model_size == "large":
        expected_shape = torch.Size((1, 42, 1024))
        expected_slice = torch.tensor([[0.0133, 0.0865, 0.0095],
                                       [0.3093, -0.2576, -0.7418],
                                       [-0.1720, -0.2117, -0.2869]])
    else:  # base
        expected_shape = torch.Size((1, 42, 768))
        expected_slice = torch.tensor([[0.0037, 0.1368, -0.0091],
                                       [0.1099, 0.3329, -0.1095],
                                       [0.0765, 0.5335, 0.1179]])

    if not (outputs.last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, Expected shape is {expected_shape}"
        )
    if not torch.allclose(
            outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError

    # Verify entity hidden states
    if model_size == "large":
        expected_shape = torch.Size((1, 1, 1024))
        expected_slice = torch.tensor([[0.0466, -0.0106, -0.0179]])
    else:  # base
        expected_shape = torch.Size((1, 1, 768))
        expected_slice = torch.tensor([[0.1457, 0.1044, 0.0174]])

    if not (outputs.entity_last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is {expected_shape}"
        )
    if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3],
                          expected_slice,
                          atol=1e-4):
        raise ValueError

    # Finally, save our PyTorch model and tokenizer
    print("Saving PyTorch model to {}".format(pytorch_dump_folder_path))
    model.save_pretrained(pytorch_dump_folder_path)
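The load_entity_vocab helper called near the top of this function is not included in the excerpt. A minimal sketch, assuming the original LUKE entity vocabulary is a TSV file with one "title<TAB>count" entry per line, stored in index order:

def load_entity_vocab(entity_vocab_path):
    # Assumed format: one "title\tcount" line per entity, where the line index is the entity id.
    entity_vocab = {}
    with open(entity_vocab_path, "r", encoding="utf-8") as f:
        for index, line in enumerate(f):
            title, _ = line.rstrip().split("\t")
            entity_vocab[title] = index
    return entity_vocab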
def main():
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    handler = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[handler])
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets for the token classification task available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called
    # 'tokens' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
    else:
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = args.train_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # Trim a number of training examples
    if args.debug:
        for split in raw_datasets.keys():
            raw_datasets[split] = raw_datasets[split].select(range(100))
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    if raw_datasets["train"] is not None:
        column_names = raw_datasets["train"].column_names
        features = raw_datasets["train"].features
    else:
        column_names = raw_datasets["validation"].column_names
        features = raw_datasets["validation"].features

    if args.text_column_name is not None:
        text_column_name = args.text_column_name
    elif "tokens" in column_names:
        text_column_name = "tokens"
    else:
        text_column_name = column_names[0]

    if args.label_column_name is not None:
        label_column_name = args.label_column_name
    elif f"{args.task_name}_tags" in column_names:
        label_column_name = f"{args.task_name}_tags"
    else:
        label_column_name = column_names[1]

    # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
    # unique labels.
    def get_label_list(labels):
        unique_labels = set()
        for label in labels:
            unique_labels = unique_labels | set(label)
        label_list = list(unique_labels)
        label_list.sort()
        return label_list

    if isinstance(features[label_column_name].feature, ClassLabel):
        label_list = features[label_column_name].feature.names
        # No need to convert the labels since they are already ints.
    else:
        label_list = get_label_list(raw_datasets["train"][label_column_name])
    num_labels = len(label_list)

    # Map that sends B-Xxx label to its I-Xxx counterpart
    b_to_i_label = []

    for idx, label in enumerate(label_list):
        if label.startswith("B-") and label.replace("B-", "I-") in label_list:
            b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
        else:
            b_to_i_label.append(idx)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = LukeConfig.from_pretrained(args.config_name, num_labels=num_labels)
    elif args.model_name_or_path:
        config = LukeConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels)
    else:
        logger.warning("You are instantiating a new config instance from scratch.")

    tokenizer_name_or_path = args.tokenizer_name if args.tokenizer_name else args.model_name_or_path
    if not tokenizer_name_or_path:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    tokenizer = LukeTokenizer.from_pretrained(
        tokenizer_name_or_path,
        use_fast=False,
        task="entity_span_classification",
        max_entity_length=args.max_entity_length,
        max_mention_length=args.max_mention_length,
    )

    if args.model_name_or_path:
        model = LukeForEntitySpanClassification.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
        )
    else:
        logger.info("Training new model from scratch")
        model = LukeForEntitySpanClassification(config)

    model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    padding = "max_length" if args.pad_to_max_length else False

    def compute_sentence_boundaries_for_luke(examples):
        sentence_boundaries = []

        for tokens in examples[text_column_name]:
            sentence_boundaries.append([0, len(tokens)])

        examples["sentence_boundaries"] = sentence_boundaries

        return examples

    def compute_entity_spans_for_luke(examples):
        all_entity_spans = []
        texts = []
        all_labels_entity_spans = []
        all_original_entity_spans = []

        for labels, tokens, sentence_boundaries in zip(
            examples[label_column_name], examples[text_column_name], examples["sentence_boundaries"]
        ):
            subword_lengths = [len(tokenizer.tokenize(token)) for token in tokens]
            total_subword_length = sum(subword_lengths)
            _, context_end = sentence_boundaries

            if total_subword_length > args.max_length - 2:
                cur_length = sum(subword_lengths[:context_end])
                idx = context_end - 1

                while cur_length > args.max_length - 2:
                    cur_length -= subword_lengths[idx]
                    context_end -= 1
                    idx -= 1

            text = ""
            sentence_words = tokens[:context_end]
            sentence_subword_lengths = subword_lengths[:context_end]
            word_start_char_positions = []
            word_end_char_positions = []
            labels_positions = {}

            for word, label in zip(sentence_words, labels):
                if word[0] == "'" or (len(word) == 1 and is_punctuation(word)):
                    text = text.rstrip()

                word_start_char_positions.append(len(text))
                text += word
                word_end_char_positions.append(len(text))
                text += " "
                labels_positions[(word_start_char_positions[-1], word_end_char_positions[-1])] = label

            text = text.rstrip()
            texts.append(text)
            entity_spans = []
            labels_entity_spans = []
            original_entity_spans = []

            for word_start in range(len(sentence_words)):
                for word_end in range(word_start, len(sentence_words)):
                    if (
                        sum(sentence_subword_lengths[word_start:word_end]) <= tokenizer.max_mention_length
                        and len(entity_spans) < tokenizer.max_entity_length
                    ):
                        entity_spans.append((word_start_char_positions[word_start], word_end_char_positions[word_end]))
                        original_entity_spans.append((word_start, word_end + 1))
                        if (
                            word_start_char_positions[word_start],
                            word_end_char_positions[word_end],
                        ) in labels_positions:
                            labels_entity_spans.append(
                                labels_positions[
                                    (word_start_char_positions[word_start], word_end_char_positions[word_end])
                                ]
                            )
                        else:
                            labels_entity_spans.append(0)

            all_entity_spans.append(entity_spans)
            all_labels_entity_spans.append(labels_entity_spans)
            all_original_entity_spans.append(original_entity_spans)

        examples["entity_spans"] = all_entity_spans
        examples["text"] = texts
        examples["labels_entity_spans"] = all_labels_entity_spans
        examples["original_entity_spans"] = all_original_entity_spans

        return examples

    def tokenize_and_align_labels(examples):
        entity_spans = []

        for v in examples["entity_spans"]:
            entity_spans.append(list(map(tuple, v)))

        tokenized_inputs = tokenizer(
            examples["text"],
            entity_spans=entity_spans,
            max_length=args.max_length,
            padding=padding,
            truncation=True,
        )

        if padding == "max_length":
            tokenized_inputs["labels"] = padding_tensor(
                examples["labels_entity_spans"], -100, tokenizer.padding_side, tokenizer.max_entity_length
            )
            tokenized_inputs["original_entity_spans"] = padding_tensor(
                examples["original_entity_spans"], (-1, -1), tokenizer.padding_side, tokenizer.max_entity_length
            )
            tokenized_inputs[label_column_name] = padding_tensor(
                examples[label_column_name], -1, tokenizer.padding_side, tokenizer.max_entity_length
            )
        else:
            tokenized_inputs["labels"] = [ex[: tokenizer.max_entity_length] for ex in examples["labels_entity_spans"]]
            tokenized_inputs["original_entity_spans"] = [
                ex[: tokenizer.max_entity_length] for ex in examples["original_entity_spans"]
            ]
            tokenized_inputs[label_column_name] = [
                ex[: tokenizer.max_entity_length] for ex in examples[label_column_name]
            ]

        return tokenized_inputs

    with accelerator.main_process_first():
        raw_datasets = raw_datasets.map(
            compute_sentence_boundaries_for_luke,
            batched=True,
            desc="Adding sentence boundaries",
        )
        raw_datasets = raw_datasets.map(
            compute_entity_spans_for_luke,
            batched=True,
            desc="Adding sentence spans",
        )

        processed_raw_datasets = raw_datasets.map(
            tokenize_and_align_labels,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
            desc="Running tokenizer on dataset",
        )

    train_dataset = processed_raw_datasets["train"]
    eval_dataset = processed_raw_datasets["validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # DataLoaders creation:
    if args.pad_to_max_length:
        # If padding was already done to max length, we use the default data collator that will just convert everything
        # to tensors.
        data_collator = default_data_collator
    else:
        # Otherwise, `DataCollatorForTokenClassification` will apply dynamic padding for us (by padding to the maximum length of
        # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to a
        # multiple of 8, which enables Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
        data_collator = DataCollatorForLukeTokenClassification(
            tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)
        )

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Use the device given by the `accelerator` object.
    device = accelerator.device
    model.to(device)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
    # shorter in a multi-process setup)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Metrics
    metric = load_metric("seqeval")

    def get_luke_labels(outputs, ner_tags, original_entity_spans):
        true_predictions = []
        true_labels = []

        for output, original_spans, tags in zip(outputs.logits, original_entity_spans, ner_tags):
            true_tags = [val for val in tags if val != -1]
            true_original_spans = [val for val in original_spans if val != (-1, -1)]
            max_indices = torch.argmax(output, axis=1)
            max_logits = torch.max(output, axis=1).values
            predictions = []

            for logit, index, span in zip(max_logits, max_indices, true_original_spans):
                if index != 0:
                    predictions.append((logit, span, label_list[index]))

            predicted_sequence = [label_list[0]] * len(true_tags)

            for _, span, label in sorted(predictions, key=lambda o: o[0], reverse=True):
                if all([o == label_list[0] for o in predicted_sequence[span[0] : span[1]]]):
                    predicted_sequence[span[0]] = label
                    if span[1] - span[0] > 1:
                        predicted_sequence[span[0] + 1 : span[1]] = [label] * (span[1] - span[0] - 1)

            true_predictions.append(predicted_sequence)
            true_labels.append([label_list[tag_id] for tag_id in true_tags])

        return true_predictions, true_labels

    def compute_metrics():
        results = metric.compute()
        if args.return_entity_level_metrics:
            # Unpack nested dictionaries
            final_results = {}
            for key, value in results.items():
                if isinstance(value, dict):
                    for n, v in value.items():
                        final_results[f"{key}_{n}"] = v
                else:
                    final_results[key] = value
            return final_results
        else:
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0

    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            _ = batch.pop("original_entity_spans")
            outputs = model(**batch)
            loss = outputs.loss
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        for step, batch in enumerate(eval_dataloader):
            original_entity_spans = batch.pop("original_entity_spans")
            with torch.no_grad():
                outputs = model(**batch)

            preds, refs = get_luke_labels(outputs, batch[label_column_name], original_entity_spans)

            metric.add_batch(
                predictions=preds,
                references=refs,
            )  # predictions and references are expected to be a nested list of labels, not label_ids

        eval_metric = compute_metrics()
        accelerator.print(f"epoch {epoch}:", eval_metric)

        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
                tokenizer.save_pretrained(args.output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
                )

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
    def prepare_config_and_inputs(self):
        # prepare words
        input_ids = ids_tensor([self.batch_size, self.seq_length],
                               self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            attention_mask = random_attention_mask(
                [self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length],
                                        self.type_vocab_size)

        # prepare entities
        entity_ids = ids_tensor([self.batch_size, self.entity_length],
                                self.entity_vocab_size)

        entity_attention_mask = None
        if self.use_entity_attention_mask:
            entity_attention_mask = random_attention_mask(
                [self.batch_size, self.entity_length])

        entity_token_type_ids = None
        if self.use_token_type_ids:
            entity_token_type_ids = ids_tensor(
                [self.batch_size, self.entity_length], self.type_vocab_size)

        entity_position_ids = None
        if self.use_entity_position_ids:
            entity_position_ids = ids_tensor(
                [self.batch_size, self.entity_length, self.mention_length],
                self.mention_length)

        sequence_labels = None
        entity_classification_labels = None
        entity_pair_classification_labels = None
        entity_span_classification_labels = None

        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size],
                                         self.type_sequence_label_size)
            entity_classification_labels = ids_tensor(
                [self.batch_size], self.num_entity_classification_labels)
            entity_pair_classification_labels = ids_tensor(
                [self.batch_size], self.num_entity_pair_classification_labels)
            entity_span_classification_labels = ids_tensor(
                [self.batch_size, self.entity_length],
                self.num_entity_span_classification_labels)

        config = LukeConfig(
            vocab_size=self.vocab_size,
            entity_vocab_size=self.entity_vocab_size,
            entity_emb_size=self.entity_emb_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            use_entity_aware_attention=self.use_entity_aware_attention,
        )

        return (
            config,
            input_ids,
            attention_mask,
            token_type_ids,
            entity_ids,
            entity_attention_mask,
            entity_token_type_ids,
            entity_position_ids,
            sequence_labels,
            entity_classification_labels,
            entity_pair_classification_labels,
            entity_span_classification_labels,
        )
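The tuple returned above is typically unpacked by the individual test methods. A minimal sketch of such a consumer follows; the method name and the plain asserts are illustrative, and it assumes the entity inputs were actually generated (i.e. the corresponding use_* flags are True).

    def create_and_check_model(self, config, input_ids, attention_mask, token_type_ids,
                               entity_ids, entity_attention_mask, entity_token_type_ids,
                               entity_position_ids, *labels):
        # Hedged sketch: run the base model on the prepared inputs; the label tensors are unused here.
        model = LukeModel(config=config).eval()
        result = model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            entity_ids=entity_ids,
            entity_attention_mask=entity_attention_mask,
            entity_token_type_ids=entity_token_type_ids,
            entity_position_ids=entity_position_ids,
        )
        assert result.last_hidden_state.shape == (self.batch_size, self.seq_length, self.hidden_size)
        assert result.entity_last_hidden_state.shape == (self.batch_size, self.entity_length, self.hidden_size)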
Example #5
def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path,
                            pytorch_dump_folder_path, model_size):
    # Load configuration defined in the metadata file
    with open(metadata_path) as metadata_file:
        metadata = json.load(metadata_file)
    config = LukeConfig(use_entity_aware_attention=True,
                        **metadata["model_config"])

    # Load in the weights from the checkpoint_path
    state_dict = torch.load(checkpoint_path, map_location="cpu")["module"]

    # Load the entity vocab file
    entity_vocab = load_original_entity_vocab(entity_vocab_path)
    # add an entry for [MASK2]
    entity_vocab["[MASK2]"] = max(entity_vocab.values()) + 1
    config.entity_vocab_size += 1

    tokenizer = XLMRobertaTokenizer.from_pretrained(
        metadata["model_config"]["bert_model_name"])

    # Add special tokens to the token vocabulary for downstream tasks
    entity_token_1 = AddedToken("<ent>", lstrip=False, rstrip=False)
    entity_token_2 = AddedToken("<ent2>", lstrip=False, rstrip=False)
    tokenizer.add_special_tokens(
        dict(additional_special_tokens=[entity_token_1, entity_token_2]))
    config.vocab_size += 2

    print(f"Saving tokenizer to {pytorch_dump_folder_path}")
    tokenizer.save_pretrained(pytorch_dump_folder_path)
    with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"),
              "r") as f:
        tokenizer_config = json.load(f)
    tokenizer_config["tokenizer_class"] = "MLukeTokenizer"
    with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"),
              "w") as f:
        json.dump(tokenizer_config, f)

    with open(
            os.path.join(
                pytorch_dump_folder_path,
                MLukeTokenizer.vocab_files_names["entity_vocab_file"]),
            "w") as f:
        json.dump(entity_vocab, f)

    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)

    # Initialize the embeddings of the special tokens
    ent_init_index = tokenizer.convert_tokens_to_ids(["@"])[0]
    ent2_init_index = tokenizer.convert_tokens_to_ids(["#"])[0]

    word_emb = state_dict["embeddings.word_embeddings.weight"]
    ent_emb = word_emb[ent_init_index].unsqueeze(0)
    ent2_emb = word_emb[ent2_init_index].unsqueeze(0)
    state_dict["embeddings.word_embeddings.weight"] = torch.cat(
        [word_emb, ent_emb, ent2_emb])
    # add special tokens for 'entity_predictions.bias'
    for bias_name in ["lm_head.decoder.bias", "lm_head.bias"]:
        decoder_bias = state_dict[bias_name]
        ent_decoder_bias = decoder_bias[ent_init_index].unsqueeze(0)
        ent2_decoder_bias = decoder_bias[ent2_init_index].unsqueeze(0)
        state_dict[bias_name] = torch.cat(
            [decoder_bias, ent_decoder_bias, ent2_decoder_bias])

    # Initialize the query layers of the entity-aware self-attention mechanism
    for layer_index in range(config.num_hidden_layers):
        for matrix_name in ["query.weight", "query.bias"]:
            prefix = f"encoder.layer.{layer_index}.attention.self."
            state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix +
                                                                   matrix_name]
            state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix +
                                                                   matrix_name]
            state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix +
                                                                   matrix_name]

    # Initialize the embedding of the [MASK2] entity using that of the [MASK] entity for downstream tasks
    entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"]
    entity_mask_emb = entity_emb[entity_vocab["[MASK]"]].unsqueeze(0)
    state_dict["entity_embeddings.entity_embeddings.weight"] = torch.cat(
        [entity_emb, entity_mask_emb])
    # add [MASK2] for 'entity_predictions.bias'
    entity_prediction_bias = state_dict["entity_predictions.bias"]
    entity_mask_bias = entity_prediction_bias[
        entity_vocab["[MASK]"]].unsqueeze(0)
    state_dict["entity_predictions.bias"] = torch.cat(
        [entity_prediction_bias, entity_mask_bias])

    model = LukeForMaskedLM(config=config).eval()

    state_dict.pop("entity_predictions.decoder.weight")
    state_dict.pop("lm_head.decoder.weight")
    state_dict.pop("lm_head.decoder.bias")
    state_dict_for_hugging_face = OrderedDict()
    for key, value in state_dict.items():
        if not (key.startswith("lm_head")
                or key.startswith("entity_predictions")):
            state_dict_for_hugging_face[f"luke.{key}"] = state_dict[key]
        else:
            state_dict_for_hugging_face[key] = state_dict[key]

    missing_keys, unexpected_keys = model.load_state_dict(
        state_dict_for_hugging_face, strict=False)

    if set(unexpected_keys) != {"luke.embeddings.position_ids"}:
        raise ValueError(f"Unexpected unexpected_keys: {unexpected_keys}")
    if set(missing_keys) != {
            "lm_head.decoder.weight",
            "lm_head.decoder.bias",
            "entity_predictions.decoder.weight",
    }:
        raise ValueError(f"Unexpected missing_keys: {missing_keys}")

    model.tie_weights()
    assert (model.luke.embeddings.word_embeddings.weight ==
            model.lm_head.decoder.weight).all()
    assert (model.luke.entity_embeddings.entity_embeddings.weight ==
            model.entity_predictions.decoder.weight).all()

    # Check outputs
    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path,
                                               task="entity_classification")

    text = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
    span = (0, 9)
    encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")

    outputs = model(**encoding)

    # Verify word hidden states
    if model_size == "large":
        raise NotImplementedError
    else:  # base
        expected_shape = torch.Size((1, 33, 768))
        expected_slice = torch.tensor([[0.0892, 0.0596, -0.2819],
                                       [0.0134, 0.1199, 0.0573],
                                       [-0.0169, 0.0927, 0.0644]])

    if not (outputs.last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, Expected shape is {expected_shape}"
        )
    if not torch.allclose(
            outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError

    # Verify entity hidden states
    if model_size == "large":
        raise NotImplementedError
    else:  # base
        expected_shape = torch.Size((1, 1, 768))
        expected_slice = torch.tensor([[-0.1482, 0.0609, 0.0322]])

    if not (outputs.entity_last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is"
            f" {expected_shape}")
    if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3],
                          expected_slice,
                          atol=1e-4):
        raise ValueError

    # Verify masked word/entity prediction
    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)
    text = "Tokyo is the capital of <mask>."
    span = (24, 30)
    encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")

    outputs = model(**encoding)

    input_ids = encoding["input_ids"][0].tolist()
    mask_position_id = input_ids.index(
        tokenizer.convert_tokens_to_ids("<mask>"))
    predicted_id = outputs.logits[0][mask_position_id].argmax(dim=-1)
    assert "Japan" == tokenizer.decode(predicted_id)

    predicted_entity_id = outputs.entity_logits[0][0].argmax().item()
    multilingual_predicted_entities = [
        entity for entity, entity_id in tokenizer.entity_vocab.items()
        if entity_id == predicted_entity_id
    ]
    assert [e for e in multilingual_predicted_entities
            if e.startswith("en:")][0] == "en:Japan"

    # Finally, save our PyTorch model and tokenizer
    print("Saving PyTorch model to {}".format(pytorch_dump_folder_path))
    model.save_pretrained(pytorch_dump_folder_path)
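In the full conversion script this function would be driven from the command line; a minimal sketch of such an entry point follows (the flag names mirror the function's parameters and are assumptions here, not confirmed from the source):

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Flag names are hypothetical; they simply mirror convert_luke_checkpoint's signature.
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the original checkpoint file.")
    parser.add_argument("--metadata_path", type=str, required=True, help="Path to the metadata JSON file.")
    parser.add_argument("--entity_vocab_path", type=str, required=True, help="Path to the entity vocabulary file.")
    parser.add_argument("--pytorch_dump_folder_path", type=str, required=True, help="Where to save the converted model and tokenizer.")
    parser.add_argument("--model_size", type=str, default="base", choices=["base", "large"], help="Size of the model to convert.")
    args = parser.parse_args()

    convert_luke_checkpoint(
        args.checkpoint_path,
        args.metadata_path,
        args.entity_vocab_path,
        args.pytorch_dump_folder_path,
        args.model_size,
    )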