def main(): # Load generated model file archive = load_archive(args.archive_path) model = archive.model finetuned_elmo_state_dict = model._contextualizer._elmo.state_dict() # Load ELMo options and weights file elmo = Elmo(args.options_file, args.weight_file, 1) original_elmo_state_dict = elmo.state_dict() # Get the average parameter shift in the token embedder. token_embedder_total_shift = 0.0 token_embedder_num_params = 0.0 for key, parameter in finetuned_elmo_state_dict.items(): if "token_embedder" in key: token_embedder_num_params += parameter.numel() token_embedder_total_shift += torch.abs( parameter - original_elmo_state_dict[key]).sum().item() logger.info("Average Shift (L1 distance) in token embedder: {}".format( token_embedder_total_shift / token_embedder_num_params)) # Get the average parameter shift in the first layer of the LSTM. layer_0_total_shift = 0.0 layer_0_num_params = 0.0 for key, parameter in finetuned_elmo_state_dict.items(): if "backward_layer_0" in key or "forward_layer_0" in key: layer_0_num_params += parameter.numel() layer_0_total_shift += torch.abs( parameter - original_elmo_state_dict[key]).sum().item() logger.info("Average Shift (L1 distance) in LSTM Layer 0: {}".format( layer_0_total_shift / layer_0_num_params)) # Get the average parameter shift in the second layer of the LSTM. layer_1_total_shift = 0.0 layer_1_num_params = 0.0 for key, parameter in finetuned_elmo_state_dict.items(): if "backward_layer_1" in key or "forward_layer_1" in key: layer_1_num_params += parameter.numel() layer_1_total_shift += torch.abs( parameter - original_elmo_state_dict[key]).sum().item() logger.info("Average Shift (L1 distance) in LSTM Layer 1: {}".format( layer_1_total_shift / layer_1_num_params)) # Print the scalar mix parameters of the fine-tuned model. normed_scalars = torch.nn.functional.softmax(torch.cat([ parameter for key, parameter in finetuned_elmo_state_dict.items() if "scalar_parameters" in key ]), dim=0) normed_scalars = torch.split(normed_scalars, split_size_or_sections=1) normed_scalars = [normed_scalar.item() for normed_scalar in normed_scalars] logger.info( "Normalized Scalar Mix of fine-tuned model: {}".format(normed_scalars)) # Print the gamma logger.info("Gamma of fine-tuned model: {}".format( finetuned_elmo_state_dict["scalar_mix_0.gamma"].item()))
class SentenceElmo(nn.Module): def __init__(self, options_file, weight_file, tokenizer, average_mod='mean', max_seq_length=128): super().__init__() assert average_mod in {'mean', 'max', 'last'} self.elmo = Elmo(options_file=options_file, weight_file=weight_file, num_output_representations=1, requires_grad=True) self.tokenizer = tokenizer self.average_mod = average_mod self.max_seq_length = max_seq_length def get_word_embedding_dimension(self) -> int: return self.elmo.get_output_dim() def forward(self, features): output = self.elmo(features['input_ids']) token_embeddings = output['elmo_representations'][0] features = {} if self.average_mod == 'mean': features['sentence_embedding'] = token_embeddings.mean(axis=1) elif self.average_mod == 'max': features['sentence_embedding'] = token_embeddings.max( axis=1).values else: last_token_indices = output['mask'].sum(axis=1) - 1 features['sentence_embedding'] = token_embeddings[ torch.arange(token_embeddings.shape[0]), last_token_indices, :] return features def tokenize(self, texts: List[str]): tokenized_texts = [ self.tokenizer.tokenize(text)[:self.max_seq_length] for text in texts ] input_ids = batch_to_ids(tokenized_texts) output = {'input_ids': input_ids} return output def save(self, output_path: str): torch.save(self.elmo.state_dict(), os.path.join(output_path, 'model.pth'))
class ContextualControllerELMo(ControllerBase): def __init__( self, hidden_size, dropout, pretrained_embeddings_dir, dataset_name, fc_hidden_size=150, freeze_pretrained=True, learning_rate=0.001, layer_learning_rate: Optional[Dict[str, float]] = None, max_segment_size=None, # if None, process sentences independently max_span_size=10, model_name=None): self.hidden_size = hidden_size self.dropout = dropout self.freeze_pretrained = freeze_pretrained self.fc_hidden_size = fc_hidden_size self.max_span_size = max_span_size self.max_segment_size = max_segment_size self.learning_rate = learning_rate self.layer_learning_rate = layer_learning_rate if layer_learning_rate is not None else {} self.pretrained_embeddings_dir = pretrained_embeddings_dir self.embedder = Elmo( options_file=os.path.join(pretrained_embeddings_dir, "options.json"), weight_file=os.path.join(pretrained_embeddings_dir, "slovenian-elmo-weights.hdf5"), dropout=(0.0 if freeze_pretrained else dropout), num_output_representations=1, requires_grad=(not freeze_pretrained)).to(DEVICE) embedding_size = self.embedder.get_output_dim() self.context_encoder = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, batch_first=True, bidirectional=True).to(DEVICE) self.scorer = NeuralCoreferencePairScorer(num_features=(2 * hidden_size), hidden_size=fc_hidden_size, dropout=dropout).to(DEVICE) params_to_update = [{ "params": self.scorer.parameters(), "lr": self.layer_learning_rate.get("lr_scorer", self.learning_rate) }, { "params": self.context_encoder.parameters(), "lr": self.layer_learning_rate.get("lr_context_encoder", self.learning_rate) }] if not freeze_pretrained: params_to_update.append({ "params": self.embedder.parameters(), "lr": self.layer_learning_rate.get("lr_embedder", self.learning_rate) }) self.optimizer = optim.Adam(params_to_update, lr=self.learning_rate) super().__init__(learning_rate=learning_rate, dataset_name=dataset_name, model_name=model_name) logging.info( f"Initialized contextual ELMo-based model with name {self.model_name}." ) @property def model_base_dir(self): return "contextual_model_elmo" def train_mode(self): if not self.freeze_pretrained: self.embedder.train() self.context_encoder.train() self.scorer.train() def eval_mode(self): self.embedder.eval() self.context_encoder.eval() self.scorer.eval() def load_checkpoint(self): self.loaded_from_file = True self.context_encoder.load_state_dict( torch.load(os.path.join(self.path_model_dir, "context_encoder.th"), map_location=DEVICE)) self.scorer.load_state_dict( torch.load(os.path.join(self.path_model_dir, "scorer.th"), map_location=DEVICE)) path_to_embeddings = os.path.join(self.path_model_dir, "embeddings.th") if os.path.isfile(path_to_embeddings): logging.info( f"Loading fine-tuned ELMo weights from '{path_to_embeddings}'") self.embedder.load_state_dict( torch.load(path_to_embeddings, map_location=DEVICE)) @staticmethod def from_pretrained(model_dir): controller_config_path = os.path.join(model_dir, "controller_config.json") with open(controller_config_path, "r", encoding="utf-8") as f_config: pre_config = json.load(f_config) instance = ContextualControllerELMo(**pre_config) instance.load_checkpoint() return instance def save_pretrained(self, model_dir): if not os.path.exists(model_dir): os.makedirs(model_dir) # Write controller config (used for instantiation) controller_config_path = os.path.join(model_dir, "controller_config.json") with open(controller_config_path, "w", encoding="utf-8") as f_config: json.dump( { "hidden_size": self.hidden_size, "dropout": self.dropout, "pretrained_embeddings_dir": self.pretrained_embeddings_dir, "dataset_name": self.dataset_name, "fc_hidden_size": self.fc_hidden_size, "freeze_pretrained": self.freeze_pretrained, "learning_rate": self.learning_rate, "layer_learning_rate": self.layer_learning_rate, "max_segment_size": self.max_segment_size, "max_span_size": self.max_span_size, "model_name": self.model_name }, fp=f_config, indent=4) torch.save(self.context_encoder.state_dict(), os.path.join(self.path_model_dir, "context_encoder.th")) torch.save(self.scorer.state_dict(), os.path.join(self.path_model_dir, "scorer.th")) # Save fine-tuned ELMo embeddings only if they're not frozen if not self.freeze_pretrained: torch.save(self.embedder.state_dict(), os.path.join(self.path_model_dir, "embeddings.th")) def save_checkpoint(self): logging.warning( "save_checkpoint() is deprecated. Use save_pretrained() instead") self.save_pretrained(self.path_model_dir) def _prepare_doc(self, curr_doc: Document) -> Dict: """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since data inside same document does not get shuffled. """ ret = {} # By default, each sentence is its own segment, meaning sentences are processed independently if self.max_segment_size is None: def get_position(t): return t.sentence_index, t.position_in_sentence _encoded_segments = batch_to_ids(curr_doc.raw_sentences()) # Optionally, one can specify max_segment_size, in which case segments of tokens are processed independently else: def get_position(t): doc_position = t.position_in_document return doc_position // self.max_segment_size, doc_position % self.max_segment_size flattened_doc = list(chain(*curr_doc.raw_sentences())) num_segments = (len(flattened_doc) + self.max_segment_size - 1) // self.max_segment_size _encoded_segments = \ batch_to_ids([flattened_doc[idx_seg * self.max_segment_size: (idx_seg + 1) * self.max_segment_size] for idx_seg in range(num_segments)]) encoded_segments = [] # Convention: Add a PAD word ([0] * max_chars vector) at the end of each segment, for padding mentions for curr_sent in _encoded_segments: encoded_segments.append( torch.cat((curr_sent, torch.zeros( (1, ELMoCharacterMapper.max_word_length), dtype=torch.long)))) encoded_segments = torch.stack(encoded_segments) cluster_sets = [] mention_to_cluster_id = {} for i, curr_cluster in enumerate(curr_doc.clusters): cluster_sets.append(set(curr_cluster)) for mid in curr_cluster: mention_to_cluster_id[mid] = i all_candidate_data = [] for idx_head, (head_id, head_mention) in enumerate(curr_doc.mentions.items(), 1): gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]] # Note: no data for dummy antecedent (len(`features`) is one less than `candidates`) candidates, candidate_data = [None], [] candidate_attention = [] correct_antecedents = [] curr_head_data = [[], []] num_head_words = 0 for curr_token in head_mention.tokens: idx_segment, idx_inside_segment = get_position(curr_token) curr_head_data[0].append(idx_segment) curr_head_data[1].append(idx_inside_segment) num_head_words += 1 if num_head_words > self.max_span_size: curr_head_data[0] = curr_head_data[0][:self.max_span_size] curr_head_data[1] = curr_head_data[1][:self.max_span_size] else: curr_head_data[0] += [curr_head_data[0][-1] ] * (self.max_span_size - num_head_words) curr_head_data[1] += [-1 ] * (self.max_span_size - num_head_words) head_attention = torch.ones((1, self.max_span_size), dtype=torch.bool) head_attention[0, num_head_words:] = False for idx_candidate, (cand_id, cand_mention) in enumerate( curr_doc.mentions.items(), start=1): if idx_candidate >= idx_head: break candidates.append(cand_id) # Maps tokens to positions inside segments (idx_seg, idx_inside_seg) for efficient indexing later curr_candidate_data = [[], []] num_candidate_words = 0 for curr_token in cand_mention.tokens: idx_segment, idx_inside_segment = get_position(curr_token) curr_candidate_data[0].append(idx_segment) curr_candidate_data[1].append(idx_inside_segment) num_candidate_words += 1 if num_candidate_words > self.max_span_size: curr_candidate_data[0] = curr_candidate_data[ 0][:self.max_span_size] curr_candidate_data[1] = curr_candidate_data[ 1][:self.max_span_size] else: # padding tokens index into the PAD token of the last segment curr_candidate_data[0] += [curr_candidate_data[0][-1]] * ( self.max_span_size - num_candidate_words) curr_candidate_data[1] += [-1] * (self.max_span_size - num_candidate_words) candidate_data.append(curr_candidate_data) curr_attention = torch.ones((1, self.max_span_size), dtype=torch.bool) curr_attention[0, num_candidate_words:] = False candidate_attention.append(curr_attention) is_coreferent = cand_id in gt_antecedent_ids if is_coreferent: correct_antecedents.append(idx_candidate) if len(correct_antecedents) == 0: correct_antecedents.append(0) candidate_attention = torch.cat( candidate_attention) if len(candidate_attention) > 0 else [] all_candidate_data.append({ "head_id": head_id, "head_data": torch.tensor([curr_head_data]), "head_attention": head_attention, "candidates": candidates, "candidate_data": torch.tensor(candidate_data), "candidate_attention": candidate_attention, "correct_antecedents": correct_antecedents }) ret["preprocessed_segments"] = encoded_segments ret["steps"] = all_candidate_data return ret def _train_doc(self, curr_doc, eval_mode=False): """ Trains/evaluates (if `eval_mode` is True) model on specific document. Returns predictions, loss and number of examples evaluated. """ if len(curr_doc.mentions) == 0: return {}, (0.0, 0) if not hasattr(curr_doc, "_cache_elmo"): curr_doc._cache_elmo = self._prepare_doc(curr_doc) cache = curr_doc._cache_elmo # type: Dict encoded_segments = cache["preprocessed_segments"] if self.freeze_pretrained: with torch.no_grad(): res = self.embedder(encoded_segments.to(DEVICE)) else: res = self.embedder(encoded_segments.to(DEVICE)) # Note: max_segment_size is either specified at instantiation or (the length of longest sentence + 1) embedded_segments = res["elmo_representations"][ 0] # [num_segments, max_segment_size, embedding_size] (lstm_segments, _) = self.context_encoder( embedded_segments ) # [num_segments, max_segment_size, 2 * hidden_size] doc_loss, n_examples = 0.0, len(cache["steps"]) preds = {} for curr_step in cache["steps"]: head_id = curr_step["head_id"] head_data = curr_step["head_data"] candidates = curr_step["candidates"] candidate_data = curr_step["candidate_data"] correct_antecedents = curr_step["correct_antecedents"] # Note: num_candidates includes dummy antecedent + actual candidates num_candidates = len(candidates) if num_candidates == 1: curr_pred = 0 else: idx_segment = candidate_data[:, 0, :] idx_in_segment = candidate_data[:, 1, :] # [num_candidates, max_span_size, embedding_size] candidate_data = lstm_segments[idx_segment, idx_in_segment] # [1, head_size, embedding_size] head_data = lstm_segments[head_data[:, 0, :], head_data[:, 1, :]] head_data = head_data.repeat((num_candidates - 1, 1, 1)) candidate_scores = self.scorer( candidate_data, head_data, curr_step["candidate_attention"], curr_step["head_attention"].repeat( (num_candidates - 1, 1))) # [1, num_candidates] candidate_scores = torch.cat( (torch.tensor([0.0], device=DEVICE), candidate_scores.flatten())).unsqueeze(0) curr_pred = torch.argmax(candidate_scores) doc_loss += self.loss( candidate_scores.repeat((len(correct_antecedents), 1)), torch.tensor(correct_antecedents, device=DEVICE)) # { antecedent: [mention(s)] } pair existing_refs = preds.get(candidates[int(curr_pred)], []) existing_refs.append(head_id) preds[candidates[int(curr_pred)]] = existing_refs if not eval_mode: doc_loss.backward() self.optimizer.step() self.optimizer.zero_grad() return preds, (float(doc_loss), n_examples)