def get_config(self):
    """Return a ``LukeConfig`` built from the hyper-parameters stored on this object.

    Every keyword is read from the attribute of the same name on ``self``,
    except ``is_decoder``, which is always passed as ``False``.
    """
    # Attributes forwarded before the fixed ``is_decoder`` flag.
    leading_fields = (
        "vocab_size",
        "entity_vocab_size",
        "entity_emb_size",
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "intermediate_size",
        "hidden_act",
        "hidden_dropout_prob",
        "attention_probs_dropout_prob",
        "max_position_embeddings",
        "type_vocab_size",
    )
    # Attributes forwarded after it.
    trailing_fields = ("initializer_range", "use_entity_aware_attention")

    settings = {field: getattr(self, field) for field in leading_fields}
    settings["is_decoder"] = False
    settings.update((field, getattr(self, field)) for field in trailing_fields)
    return LukeConfig(**settings)
def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, pytorch_dump_folder_path, model_size):
    """Convert an original LUKE checkpoint into a saved ``LukeModel`` plus tokenizer files.

    The function loads the config from the metadata JSON, adds the ``<ent>`` /
    ``<ent2>`` special tokens (initialised from the ``@`` / ``#`` embeddings),
    copies the query projections into the entity-aware attention projections,
    copies the ``[MASK]`` entity embedding onto ``[MASK2]``, loads the weights
    into a ``LukeModel``, checks the outputs against hard-coded reference values
    and finally saves the model.

    Args:
        checkpoint_path: Path of the checkpoint handed to ``torch.load``.
        metadata_path: Path of a JSON file whose ``"model_config"`` entry holds the
            config keyword arguments, including ``"bert_model_name"``.
        entity_vocab_path: Path handed to ``load_entity_vocab``.
        pytorch_dump_folder_path: Directory that receives the tokenizer files, the
            entity vocabulary and the converted model.
        model_size: ``"large"`` selects the large reference values; any other value
            selects the base reference values.

    Raises:
        ValueError: If loading the state dict reports missing/unexpected keys other
            than the tolerated ones, or if the word or entity hidden states do not
            match the reference shape or values.
    """
    # Load configuration defined in the metadata file
    with open(metadata_path) as metadata_file:
        metadata = json.load(metadata_file)
    config = LukeConfig(use_entity_aware_attention=True, **metadata["model_config"])

    # Load in the weights from the checkpoint_path.
    # NOTE: torch.load unpickles the file, so only convert checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Load the entity vocab file
    entity_vocab = load_entity_vocab(entity_vocab_path)

    tokenizer = RobertaTokenizer.from_pretrained(metadata["model_config"]["bert_model_name"])

    # Add special tokens to the token vocabulary for downstream tasks
    entity_token_1 = AddedToken("<ent>", lstrip=False, rstrip=False)
    entity_token_2 = AddedToken("<ent2>", lstrip=False, rstrip=False)
    tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2]))
    config.vocab_size += 2

    print(f"Saving tokenizer to {pytorch_dump_folder_path}")
    tokenizer.save_pretrained(pytorch_dump_folder_path)
    with open(os.path.join(pytorch_dump_folder_path, LukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f:
        json.dump(entity_vocab, f)

    # Reload as a LukeTokenizer now that the entity vocabulary sits next to the word vocabulary
    tokenizer = LukeTokenizer.from_pretrained(pytorch_dump_folder_path)

    # Initialize the embeddings of the special tokens from those of "@" and "#"
    word_emb = state_dict["embeddings.word_embeddings.weight"]
    ent_emb = word_emb[tokenizer.convert_tokens_to_ids(["@"])[0]].unsqueeze(0)
    ent2_emb = word_emb[tokenizer.convert_tokens_to_ids(["#"])[0]].unsqueeze(0)
    state_dict["embeddings.word_embeddings.weight"] = torch.cat([word_emb, ent_emb, ent2_emb])

    # Initialize the query layers of the entity-aware self-attention mechanism
    for layer_index in range(config.num_hidden_layers):
        for matrix_name in ["query.weight", "query.bias"]:
            prefix = f"encoder.layer.{layer_index}.attention.self."
            state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name]
            state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name]
            state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name]

    # Initialize the embedding of the [MASK2] entity using that of the [MASK] entity for downstream tasks
    entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"]
    entity_emb[entity_vocab["[MASK2]"]] = entity_emb[entity_vocab["[MASK]"]]

    model = LukeModel(config=config).eval()

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if not (len(missing_keys) == 1 and missing_keys[0] == "embeddings.position_ids"):
        raise ValueError(f"Missing keys {', '.join(missing_keys)}. Expected only missing embeddings.position_ids")
    # Only the pre-training heads may be left over in the checkpoint
    offending_keys = [key for key in unexpected_keys if not key.startswith(("entity_predictions", "lm_head"))]
    if offending_keys:
        raise ValueError(f"Unexpected keys {', '.join(offending_keys)}")

    # Check outputs
    tokenizer = LukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification")

    text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
    span = (39, 42)
    encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")

    outputs = model(**encoding)

    # Verify word hidden states
    if model_size == "large":
        expected_shape = torch.Size((1, 42, 1024))
        expected_slice = torch.tensor(
            [[0.0133, 0.0865, 0.0095], [0.3093, -0.2576, -0.7418], [-0.1720, -0.2117, -0.2869]]
        )
    else:  # base
        expected_shape = torch.Size((1, 42, 768))
        expected_slice = torch.tensor([[0.0037, 0.1368, -0.0091], [0.1099, 0.3329, -0.1095], [0.0765, 0.5335, 0.1179]])

    if not (outputs.last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, Expected shape is {expected_shape}"
        )
    if not torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError("Outputs.last_hidden_state does not match the expected reference values")

    # Verify entity hidden states
    if model_size == "large":
        expected_shape = torch.Size((1, 1, 1024))
        expected_slice = torch.tensor([[0.0466, -0.0106, -0.0179]])
    else:  # base
        expected_shape = torch.Size((1, 1, 768))
        expected_slice = torch.tensor([[0.1457, 0.1044, 0.0174]])

    # The shape must EQUAL the reference, exactly like the word hidden-state check above;
    # the previous `!=` comparison rejected correct conversions and accepted wrong shapes.
    if not (outputs.entity_last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is {expected_shape}"
        )
    if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError("Outputs.entity_last_hidden_state does not match the expected reference values")

    # Finally, save our PyTorch model and tokenizer
    print(f"Saving PyTorch model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
def main():
    """Fine-tune and evaluate a LUKE entity-span-classification model with Accelerate.

    Reads the run settings from ``parse_args()``, loads the dataset, turns every
    example into a text plus candidate entity spans, runs the training loop with
    a seqeval evaluation after each epoch, and finally saves the model and
    tokenizer to ``args.output_dir`` (pushing to the Hub when
    ``args.push_to_hub`` is set).
    """
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    handler = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[handler])
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation (main process only).
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets for token classification task available on the hub at
    # https://huggingface.co/datasets/ (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called
    # 'tokens' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
    else:
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        # NOTE(review): the loader type is taken from the train file's extension only, so this line
        # fails when just a validation file is supplied (args.train_file is None) - confirm intended.
        extension = args.train_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # Trim every split to its first 100 examples when debugging.
    if args.debug:
        for split in raw_datasets.keys():
            raw_datasets[split] = raw_datasets[split].select(range(100))
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    if raw_datasets["train"] is not None:
        column_names = raw_datasets["train"].column_names
        features = raw_datasets["train"].features
    else:
        column_names = raw_datasets["validation"].column_names
        features = raw_datasets["validation"].features

    # Text column: explicit argument, else "tokens", else the first column.
    if args.text_column_name is not None:
        text_column_name = args.text_column_name
    elif "tokens" in column_names:
        text_column_name = "tokens"
    else:
        text_column_name = column_names[0]

    # Label column: explicit argument, else "<task_name>_tags", else the second column.
    if args.label_column_name is not None:
        label_column_name = args.label_column_name
    elif f"{args.task_name}_tags" in column_names:
        label_column_name = f"{args.task_name}_tags"
    else:
        label_column_name = column_names[1]

    # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
    # unique labels.
    def get_label_list(labels):
        """Return the sorted list of distinct labels found in an iterable of label sequences."""
        unique_labels = set()
        for label in labels:
            unique_labels = unique_labels | set(label)
        label_list = list(unique_labels)
        label_list.sort()
        return label_list

    if isinstance(features[label_column_name].feature, ClassLabel):
        label_list = features[label_column_name].feature.names
        # No need to convert the labels since they are already ints.
    else:
        label_list = get_label_list(raw_datasets["train"][label_column_name])
    num_labels = len(label_list)

    # Map that sends B-Xxx label to its I-Xxx counterpart
    # NOTE(review): `b_to_i_label` is not read again anywhere in this function.
    b_to_i_label = []

    for idx, label in enumerate(label_list):
        if label.startswith("B-") and label.replace("B-", "I-") in label_list:
            b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
        else:
            b_to_i_label.append(idx)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = LukeConfig.from_pretrained(args.config_name, num_labels=num_labels)
    elif args.model_name_or_path:
        config = LukeConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels)
    else:
        # NOTE(review): nothing is assigned to `config` on this path, yet `config` is read below
        # when the model is created from scratch - confirm this branch is usable.
        logger.warning("You are instantiating a new config instance from scratch.")

    tokenizer_name_or_path = args.tokenizer_name if args.tokenizer_name else args.model_name_or_path
    if not tokenizer_name_or_path:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    tokenizer = LukeTokenizer.from_pretrained(
        tokenizer_name_or_path,
        use_fast=False,
        task="entity_span_classification",
        max_entity_length=args.max_entity_length,
        max_mention_length=args.max_mention_length,
    )

    if args.model_name_or_path:
        model = LukeForEntitySpanClassification.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
        )
    else:
        logger.info("Training new model from scratch")
        model = LukeForEntitySpanClassification.from_config(config)

    model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    padding = "max_length" if args.pad_to_max_length else False

    def compute_sentence_boundaries_for_luke(examples):
        """Add a "sentence_boundaries" column holding ``[0, len(tokens)]`` for every example."""
        sentence_boundaries = []

        for tokens in examples[text_column_name]:
            sentence_boundaries.append([0, len(tokens)])

        examples["sentence_boundaries"] = sentence_boundaries

        return examples

    def compute_entity_spans_for_luke(examples):
        """Rebuild each example's text and enumerate its candidate entity spans.

        Adds the columns "text", "entity_spans" (character offsets),
        "original_entity_spans" (word offsets, end exclusive) and
        "labels_entity_spans" (one label per candidate span).
        """
        all_entity_spans = []
        texts = []
        all_labels_entity_spans = []
        all_original_entity_spans = []

        for labels, tokens, sentence_boundaries in zip(
            examples[label_column_name], examples[text_column_name], examples["sentence_boundaries"]
        ):
            subword_lengths = [len(tokenizer.tokenize(token)) for token in tokens]
            total_subword_length = sum(subword_lengths)
            _, context_end = sentence_boundaries

            # Drop trailing words until the subword count fits in max_length minus 2
            # (presumably the room kept for the two special tokens - TODO confirm).
            if total_subword_length > args.max_length - 2:
                cur_length = sum(subword_lengths[:context_end])
                idx = context_end - 1

                while cur_length > args.max_length - 2:
                    cur_length -= subword_lengths[idx]
                    context_end -= 1
                    idx -= 1

            text = ""
            sentence_words = tokens[:context_end]
            sentence_subword_lengths = subword_lengths[:context_end]
            word_start_char_positions = []
            word_end_char_positions = []
            labels_positions = {}

            for word, label in zip(sentence_words, labels):
                # Attach apostrophe-initial words and single punctuation characters to the previous word.
                if word[0] == "'" or (len(word) == 1 and is_punctuation(word)):
                    text = text.rstrip()

                word_start_char_positions.append(len(text))
                text += word
                word_end_char_positions.append(len(text))
                text += " "
                # Remember the label of this word under its (start, end) character offsets.
                labels_positions[(word_start_char_positions[-1], word_end_char_positions[-1])] = label

            text = text.rstrip()
            texts.append(text)
            entity_spans = []
            labels_entity_spans = []
            original_entity_spans = []

            for word_start in range(len(sentence_words)):
                for word_end in range(word_start, len(sentence_words)):
                    # NOTE(review): the slice below excludes `word_end` itself although the span
                    # includes that word - confirm the mention-length check is meant this way.
                    if (
                        sum(sentence_subword_lengths[word_start:word_end]) <= tokenizer.max_mention_length
                        and len(entity_spans) < tokenizer.max_entity_length
                    ):
                        entity_spans.append((word_start_char_positions[word_start], word_end_char_positions[word_end]))
                        original_entity_spans.append((word_start, word_end + 1))
                        # Spans whose offsets were recorded for a word take that word's label; all others get 0.
                        if (
                            word_start_char_positions[word_start],
                            word_end_char_positions[word_end],
                        ) in labels_positions:
                            labels_entity_spans.append(
                                labels_positions[
                                    (word_start_char_positions[word_start], word_end_char_positions[word_end])
                                ]
                            )
                        else:
                            labels_entity_spans.append(0)

            all_entity_spans.append(entity_spans)
            all_labels_entity_spans.append(labels_entity_spans)
            all_original_entity_spans.append(original_entity_spans)

        examples["entity_spans"] = all_entity_spans
        examples["text"] = texts
        examples["labels_entity_spans"] = all_labels_entity_spans
        examples["original_entity_spans"] = all_original_entity_spans

        return examples

    def tokenize_and_align_labels(examples):
        """Tokenize the texts with their entity spans and attach padded or truncated label columns."""
        entity_spans = []

        # Convert every span to a tuple before handing the spans to the tokenizer.
        for v in examples["entity_spans"]:
            entity_spans.append(list(map(tuple, v)))

        tokenized_inputs = tokenizer(
            examples["text"],
            entity_spans=entity_spans,
            max_length=args.max_length,
            padding=padding,
            truncation=True,
        )

        if padding == "max_length":
            # Pad the label-like columns with their fill values (-100, (-1, -1) and -1).
            tokenized_inputs["labels"] = padding_tensor(
                examples["labels_entity_spans"], -100, tokenizer.padding_side, tokenizer.max_entity_length
            )
            tokenized_inputs["original_entity_spans"] = padding_tensor(
                examples["original_entity_spans"], (-1, -1), tokenizer.padding_side, tokenizer.max_entity_length
            )
            tokenized_inputs[label_column_name] = padding_tensor(
                examples[label_column_name], -1, tokenizer.padding_side, tokenizer.max_entity_length
            )
        else:
            # No padding here: only truncate each column to max_entity_length items.
            tokenized_inputs["labels"] = [ex[: tokenizer.max_entity_length] for ex in examples["labels_entity_spans"]]
            tokenized_inputs["original_entity_spans"] = [
                ex[: tokenizer.max_entity_length] for ex in examples["original_entity_spans"]
            ]
            tokenized_inputs[label_column_name] = [
                ex[: tokenizer.max_entity_length] for ex in examples[label_column_name]
            ]

        return tokenized_inputs

    with accelerator.main_process_first():
        raw_datasets = raw_datasets.map(
            compute_sentence_boundaries_for_luke,
            batched=True,
            desc="Adding sentence boundaries",
        )
        raw_datasets = raw_datasets.map(
            compute_entity_spans_for_luke,
            batched=True,
            desc="Adding sentence spans",
        )

        processed_raw_datasets = raw_datasets.map(
            tokenize_and_align_labels,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
            desc="Running tokenizer on dataset",
        )

    train_dataset = processed_raw_datasets["train"]
    eval_dataset = processed_raw_datasets["validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # DataLoaders creation:
    if args.pad_to_max_length:
        # If padding was already done to max length, we use the default data collator that will just convert everything
        # to tensors.
        data_collator = default_data_collator
    else:
        # Otherwise, `DataCollatorForLukeTokenClassification` will apply dynamic padding for us (by padding to the
        # maximum length of the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all
        # tensors to multiples of 8, which will enable the use of Tensor Cores on NVIDIA hardware.
        data_collator = DataCollatorForLukeTokenClassification(
            tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)
        )

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Use the device given by the `accelerator` object.
    device = accelerator.device
    model.to(device)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
    # shorter in multiprocess)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Metrics
    metric = load_metric("seqeval")

    def get_luke_labels(outputs, ner_tags, original_entity_spans):
        """Turn span logits into per-word label sequences.

        Returns ``(true_predictions, true_labels)``, two lists holding one list
        of label names per example.
        """
        true_predictions = []
        true_labels = []

        for output, original_spans, tags in zip(outputs.logits, original_entity_spans, ner_tags):
            # Drop the padding entries (-1 tags and (-1, -1) spans).
            true_tags = [val for val in tags if val != -1]
            true_original_spans = [val for val in original_spans if val != (-1, -1)]
            max_indices = torch.argmax(output, axis=1)
            max_logits = torch.max(output, axis=1).values
            predictions = []

            # Keep only the spans whose best class is not index 0.
            for logit, index, span in zip(max_logits, max_indices, true_original_spans):
                if index != 0:
                    predictions.append((logit, span, label_list[index]))

            predicted_sequence = [label_list[0]] * len(true_tags)

            # Highest logit first: a span is written only if all of its words still carry label_list[0].
            for _, span, label in sorted(predictions, key=lambda o: o[0], reverse=True):
                if all([o == label_list[0] for o in predicted_sequence[span[0] : span[1]]]):
                    predicted_sequence[span[0]] = label
                    if span[1] - span[0] > 1:
                        predicted_sequence[span[0] + 1 : span[1]] = [label] * (span[1] - span[0] - 1)

            true_predictions.append(predicted_sequence)
            true_labels.append([label_list[tag_id] for tag_id in true_tags])

        return true_predictions, true_labels

    def compute_metrics():
        """Compute the accumulated metric and return either flattened or overall-only results."""
        results = metric.compute()
        if args.return_entity_level_metrics:
            # Unpack nested dictionaries
            final_results = {}
            for key, value in results.items():
                if isinstance(value, dict):
                    for n, v in value.items():
                        final_results[f"{key}_{n}"] = v
                else:
                    final_results[key] = value
            return final_results
        else:
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0

    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            # The word-level spans are only needed for evaluation, so drop them before the forward pass.
            _ = batch.pop("original_entity_spans")
            outputs = model(**batch)
            loss = outputs.loss
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            # NOTE(review): `step % n == 0` is already true at step 0, so the first optimizer
            # update happens after a single batch - confirm this is intended.
            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        for step, batch in enumerate(eval_dataloader):
            original_entity_spans = batch.pop("original_entity_spans")
            with torch.no_grad():
                outputs = model(**batch)

            preds, refs = get_luke_labels(outputs, batch[label_column_name], original_entity_spans)

            metric.add_batch(
                predictions=preds,
                references=refs,
            )  # predictions and references are expected to be a nested list of labels, not label_ids

        eval_metric = compute_metrics()
        accelerator.print(f"epoch {epoch}:", eval_metric)

        # Intermediate checkpoint push after every epoch except the last one.
        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
                tokenizer.save_pretrained(args.output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
                )

    # Final save (and optional final push) once training is over.
    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
def prepare_config_and_inputs(self):
    """Create a ``LukeConfig`` together with random word, entity and label inputs.

    Optional inputs are ``None`` when the matching ``use_*`` flag on ``self`` is
    falsy. The return value is the 12-tuple ``(config, input_ids,
    attention_mask, token_type_ids, entity_ids, entity_attention_mask,
    entity_token_type_ids, entity_position_ids, sequence_labels,
    entity_classification_labels, entity_pair_classification_labels,
    entity_span_classification_labels)``.
    """
    # Word-level inputs
    input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
    attention_mask = (
        random_attention_mask([self.batch_size, self.seq_length]) if self.use_attention_mask else None
    )
    token_type_ids = (
        ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) if self.use_token_type_ids else None
    )

    # Entity-level inputs
    entity_ids = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
    entity_attention_mask = (
        random_attention_mask([self.batch_size, self.entity_length]) if self.use_entity_attention_mask else None
    )
    entity_token_type_ids = (
        ids_tensor([self.batch_size, self.entity_length], self.type_vocab_size)
        if self.use_token_type_ids
        else None
    )
    entity_position_ids = (
        ids_tensor([self.batch_size, self.entity_length, self.mention_length], self.mention_length)
        if self.use_entity_position_ids
        else None
    )

    # Labels for the different heads (all absent unless ``use_labels`` is set)
    if self.use_labels:
        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
        entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
        entity_pair_classification_labels = ids_tensor(
            [self.batch_size], self.num_entity_pair_classification_labels
        )
        entity_span_classification_labels = ids_tensor(
            [self.batch_size, self.entity_length], self.num_entity_span_classification_labels
        )
    else:
        sequence_labels = None
        entity_classification_labels = None
        entity_pair_classification_labels = None
        entity_span_classification_labels = None

    # Configuration mirrors the attributes of ``self``; ``is_decoder`` is fixed to False.
    config_kwargs = {
        "vocab_size": self.vocab_size,
        "entity_vocab_size": self.entity_vocab_size,
        "entity_emb_size": self.entity_emb_size,
        "hidden_size": self.hidden_size,
        "num_hidden_layers": self.num_hidden_layers,
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self.intermediate_size,
        "hidden_act": self.hidden_act,
        "hidden_dropout_prob": self.hidden_dropout_prob,
        "attention_probs_dropout_prob": self.attention_probs_dropout_prob,
        "max_position_embeddings": self.max_position_embeddings,
        "type_vocab_size": self.type_vocab_size,
        "is_decoder": False,
        "initializer_range": self.initializer_range,
        "use_entity_aware_attention": self.use_entity_aware_attention,
    }
    config = LukeConfig(**config_kwargs)

    prepared = (
        config,
        input_ids,
        attention_mask,
        token_type_ids,
        entity_ids,
        entity_attention_mask,
        entity_token_type_ids,
        entity_position_ids,
        sequence_labels,
        entity_classification_labels,
        entity_pair_classification_labels,
        entity_span_classification_labels,
    )
    return prepared
def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, pytorch_dump_folder_path, model_size):
    """Convert an original mLUKE checkpoint into a saved ``LukeForMaskedLM`` plus tokenizer files.

    The weights are read from the ``"module"`` entry of the checkpoint. The word
    vocabulary gains ``<ent>`` / ``<ent2>`` (initialised from ``@`` / ``#``), the
    entity vocabulary gains ``[MASK2]`` (initialised from ``[MASK]``), and the
    query projections are copied into the entity-aware attention projections.
    The converted model is checked against hard-coded reference outputs and a
    masked word/entity prediction before it is saved.

    Args:
        checkpoint_path: Path of the checkpoint handed to ``torch.load``.
        metadata_path: Path of a JSON file whose ``"model_config"`` entry holds the
            config keyword arguments, including ``"bert_model_name"``.
        entity_vocab_path: Path handed to ``load_original_entity_vocab``.
        pytorch_dump_folder_path: Directory that receives the tokenizer files, the
            entity vocabulary and the converted model.
        model_size: ``"large"`` is not supported by the output check; any other
            value uses the base reference values.

    Raises:
        ValueError: If the loaded keys or the hidden states differ from what is expected.
        NotImplementedError: If ``model_size`` is ``"large"``.
        AssertionError: If weight tying or the masked predictions are not as expected.
    """
    # Configuration described by the metadata file
    with open(metadata_path) as metadata_handle:
        metadata = json.load(metadata_handle)
    model_settings = metadata["model_config"]
    config = LukeConfig(use_entity_aware_attention=True, **model_settings)

    # Checkpoint weights
    weights = torch.load(checkpoint_path, map_location="cpu")["module"]

    # Entity vocabulary, extended with a [MASK2] entry at the next free id
    entity_vocab = load_original_entity_vocab(entity_vocab_path)
    entity_vocab["[MASK2]"] = max(entity_vocab.values()) + 1
    config.entity_vocab_size += 1

    tokenizer = XLMRobertaTokenizer.from_pretrained(model_settings["bert_model_name"])

    # Two extra special tokens for downstream tasks
    extra_tokens = [
        AddedToken("<ent>", lstrip=False, rstrip=False),
        AddedToken("<ent2>", lstrip=False, rstrip=False),
    ]
    tokenizer.add_special_tokens(dict(additional_special_tokens=extra_tokens))
    config.vocab_size += 2

    print(f"Saving tokenizer to {pytorch_dump_folder_path}")
    tokenizer.save_pretrained(pytorch_dump_folder_path)

    # Rewrite the saved tokenizer config so that it names the MLukeTokenizer class
    tokenizer_config_path = os.path.join(pytorch_dump_folder_path, "tokenizer_config.json")
    with open(tokenizer_config_path, "r") as reader:
        tokenizer_config = json.load(reader)
    tokenizer_config["tokenizer_class"] = "MLukeTokenizer"
    with open(tokenizer_config_path, "w") as writer:
        json.dump(tokenizer_config, writer)

    entity_vocab_file = os.path.join(pytorch_dump_folder_path, MLukeTokenizer.vocab_files_names["entity_vocab_file"])
    with open(entity_vocab_file, "w") as writer:
        json.dump(entity_vocab, writer)

    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)

    # Rows of "@" and "#" seed the embeddings of the two new special tokens
    ent_init_index = tokenizer.convert_tokens_to_ids(["@"])[0]
    ent2_init_index = tokenizer.convert_tokens_to_ids(["#"])[0]

    word_emb = weights["embeddings.word_embeddings.weight"]
    weights["embeddings.word_embeddings.weight"] = torch.cat(
        [word_emb, word_emb[ent_init_index].unsqueeze(0), word_emb[ent2_init_index].unsqueeze(0)]
    )
    # The language-model head biases get the same two extra entries
    for bias_key in ("lm_head.decoder.bias", "lm_head.bias"):
        bias = weights[bias_key]
        weights[bias_key] = torch.cat(
            [bias, bias[ent_init_index].unsqueeze(0), bias[ent2_init_index].unsqueeze(0)]
        )

    # Entity-aware attention: every extra query projection starts as the plain query projection
    for layer_index in range(config.num_hidden_layers):
        prefix = f"encoder.layer.{layer_index}.attention.self."
        for matrix_name in ("query.weight", "query.bias"):
            query_tensor = weights[prefix + matrix_name]
            for variant in ("w2e_", "e2w_", "e2e_"):
                weights[prefix + variant + matrix_name] = query_tensor

    # [MASK2] reuses the embedding and the prediction bias of [MASK]
    mask_entity_id = entity_vocab["[MASK]"]
    entity_emb = weights["entity_embeddings.entity_embeddings.weight"]
    weights["entity_embeddings.entity_embeddings.weight"] = torch.cat(
        [entity_emb, entity_emb[mask_entity_id].unsqueeze(0)]
    )
    entity_bias = weights["entity_predictions.bias"]
    weights["entity_predictions.bias"] = torch.cat([entity_bias, entity_bias[mask_entity_id].unsqueeze(0)])

    model = LukeForMaskedLM(config=config).eval()

    # These three entries are dropped from the checkpoint and reported as missing below
    for dropped_key in ("entity_predictions.decoder.weight", "lm_head.decoder.weight", "lm_head.decoder.bias"):
        weights.pop(dropped_key)

    # Head weights keep their names; everything else moves under the "luke." prefix
    renamed_weights = OrderedDict()
    for key, tensor in weights.items():
        if key.startswith(("lm_head", "entity_predictions")):
            renamed_weights[key] = tensor
        else:
            renamed_weights[f"luke.{key}"] = tensor

    missing_keys, unexpected_keys = model.load_state_dict(renamed_weights, strict=False)

    if set(unexpected_keys) != {"luke.embeddings.position_ids"}:
        raise ValueError(f"Unexpected unexpected_keys: {unexpected_keys}")
    expected_missing = {
        "lm_head.decoder.weight",
        "lm_head.decoder.bias",
        "entity_predictions.decoder.weight",
    }
    if set(missing_keys) != expected_missing:
        raise ValueError(f"Unexpected missing_keys: {missing_keys}")

    model.tie_weights()
    assert (model.luke.embeddings.word_embeddings.weight == model.lm_head.decoder.weight).all()
    assert (model.luke.entity_embeddings.entity_embeddings.weight == model.entity_predictions.decoder.weight).all()

    # Check outputs
    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification")

    text = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
    span = (0, 9)
    encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")

    outputs = model(**encoding)

    # Verify word hidden states (reference values exist for the base model only)
    if model_size == "large":
        raise NotImplementedError
    expected_shape = torch.Size((1, 33, 768))
    expected_slice = torch.tensor([[0.0892, 0.0596, -0.2819], [0.0134, 0.1199, 0.0573], [-0.0169, 0.0927, 0.0644]])

    word_states = outputs.last_hidden_state
    if word_states.shape != expected_shape:
        raise ValueError(
            f"Outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, Expected shape is {expected_shape}"
        )
    if not torch.allclose(word_states[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError

    # Verify entity hidden states (again base model only)
    if model_size == "large":
        raise NotImplementedError
    expected_shape = torch.Size((1, 1, 768))
    expected_slice = torch.tensor([[-0.1482, 0.0609, 0.0322]])

    entity_states = outputs.entity_last_hidden_state
    if entity_states.shape != expected_shape:
        raise ValueError(
            f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is"
            f" {expected_shape}"
        )
    if not torch.allclose(entity_states[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError

    # Verify masked word/entity prediction
    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)
    text = "Tokyo is the capital of <mask>."
    span = (24, 30)
    encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")

    outputs = model(**encoding)

    token_ids = encoding["input_ids"][0].tolist()
    mask_position_id = token_ids.index(tokenizer.convert_tokens_to_ids("<mask>"))
    predicted_id = outputs.logits[0][mask_position_id].argmax(dim=-1)
    assert "Japan" == tokenizer.decode(predicted_id)

    predicted_entity_id = outputs.entity_logits[0][0].argmax().item()
    # Several entity names may share the predicted id; the check uses the first "en:" one
    multilingual_predicted_entities = [
        entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id == predicted_entity_id
    ]
    english_entities = [e for e in multilingual_predicted_entities if e.startswith("en:")]
    assert english_entities[0] == "en:Japan"

    # Finally, save our PyTorch model and tokenizer
    print(f"Saving PyTorch model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)