"""
Elliot Schumacher, Johns Hopkins University
Created 3/19/19

Includes a small smoke test (``main``) that embeds a few toy sentences with a
locally trained ELMo model and prints the result when this file is run as a
script. Importing the module no longer triggers that model load.
"""
from allennlp.modules.elmo import Elmo, batch_to_ids
import torch

# NOTE(review): machine-specific absolute paths; only the smoke test in main() uses them.
options_file = "/Users/elliotschumacher/Dropbox/git/synonym_detection/resources/bilm/out_max/options.json"
weight_file = "/Users/elliotschumacher/Dropbox/git/synonym_detection/resources/bilm/out_max/std-weights.hdf5"


def main():
    """Embed three toy sentences with the ELMo model configured above and print the output."""
    # One output representation (third positional arg = num_output_representations),
    # i.e. a single learned weighted combination of ELMo's layers.
    elmo = Elmo(options_file, weight_file, 1, dropout=0)

    sentences = [['First', 'sentence', '.'],
                 ['Another', '.'],
                 ["The", "patient", "displayed", "signs", "of", "diabetes"]]
    # batch_to_ids converts tokens to padded character ids:
    # (num_sentences, longest_sentence_length, max_chars_per_word)
    character_ids = batch_to_ids(sentences)
    embeddings = elmo(character_ids)
    # embeddings['elmo_representations'] is a list with ONE tensor (one output
    # representation was requested) of shape (3, 6, dim):
    #   3   - the batch size (number of sentences)
    #   6   - the length of the longest sentence (shorter ones are padded)
    #   dim - elmo.get_output_dim(), determined by the options file
    print(embeddings)


if __name__ == "__main__":
    main()
class ContextualControllerELMo(ControllerBase):
    """Mention-pair coreference controller: ELMo embeddings -> BiLSTM -> pair scorer.

    Tokens are embedded with ELMo (frozen or fine-tuned), contextualized by a
    bidirectional LSTM, and each (candidate antecedent, mention) pair is scored
    by a NeuralCoreferencePairScorer. Candidate index 0 is a dummy antecedent
    with a fixed score of 0.0, meaning "no antecedent".
    """

    def __init__(self, hidden_size, dropout, pretrained_embeddings_dir, dataset_name,
                 fc_hidden_size=150, freeze_pretrained=True, learning_rate=0.001,
                 layer_learning_rate: Optional[Dict[str, float]] = None,
                 max_segment_size=None,  # if None, process sentences independently
                 max_span_size=10, model_name=None):
        """
        Args:
            hidden_size: hidden size of each direction of the context BiLSTM.
            dropout: dropout of the pair scorer (also used inside ELMo when it is not frozen).
            pretrained_embeddings_dir: directory containing "options.json" and
                "slovenian-elmo-weights.hdf5".
            dataset_name: forwarded to ControllerBase.
            fc_hidden_size: hidden size of the pair scorer.
            freeze_pretrained: if True, ELMo weights are not updated.
            learning_rate: default learning rate for every parameter group.
            layer_learning_rate: optional per-group overrides; recognized keys are
                "lr_scorer", "lr_context_encoder" and "lr_embedder".
            max_segment_size: if None, each sentence is encoded independently; otherwise
                the document is split into consecutive chunks of this many tokens.
            max_span_size: mentions are truncated/padded to this many tokens.
            model_name: forwarded to ControllerBase.
        """
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.freeze_pretrained = freeze_pretrained
        self.fc_hidden_size = fc_hidden_size
        self.max_span_size = max_span_size
        self.max_segment_size = max_segment_size
        self.learning_rate = learning_rate
        self.layer_learning_rate = layer_learning_rate if layer_learning_rate is not None else {}
        self.pretrained_embeddings_dir = pretrained_embeddings_dir

        self.embedder = Elmo(options_file=os.path.join(pretrained_embeddings_dir, "options.json"),
                             weight_file=os.path.join(pretrained_embeddings_dir, "slovenian-elmo-weights.hdf5"),
                             dropout=(0.0 if freeze_pretrained else dropout),
                             num_output_representations=1,
                             requires_grad=(not freeze_pretrained)).to(DEVICE)
        embedding_size = self.embedder.get_output_dim()

        self.context_encoder = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size,
                                       batch_first=True, bidirectional=True).to(DEVICE)
        self.scorer = NeuralCoreferencePairScorer(num_features=(2 * hidden_size),
                                                  hidden_size=fc_hidden_size,
                                                  dropout=dropout).to(DEVICE)

        params_to_update = [
            {
                "params": self.scorer.parameters(),
                "lr": self.layer_learning_rate.get("lr_scorer", self.learning_rate)
            },
            {
                "params": self.context_encoder.parameters(),
                "lr": self.layer_learning_rate.get("lr_context_encoder", self.learning_rate)
            }
        ]
        # ELMo parameters are only optimized when they are being fine-tuned
        if not freeze_pretrained:
            params_to_update.append({
                "params": self.embedder.parameters(),
                "lr": self.layer_learning_rate.get("lr_embedder", self.learning_rate)
            })

        self.optimizer = optim.Adam(params_to_update, lr=self.learning_rate)

        super().__init__(learning_rate=learning_rate, dataset_name=dataset_name, model_name=model_name)
        logging.info(f"Initialized contextual ELMo-based model with name {self.model_name}.")

    @property
    def model_base_dir(self):
        return "contextual_model_elmo"

    def train_mode(self):
        """Put trainable submodules into training mode (a frozen ELMo is left untouched)."""
        if not self.freeze_pretrained:
            self.embedder.train()
        self.context_encoder.train()
        self.scorer.train()

    def eval_mode(self):
        """Put all submodules into evaluation mode."""
        self.embedder.eval()
        self.context_encoder.eval()
        self.scorer.eval()

    def load_checkpoint(self, model_dir=None):
        """Load encoder, scorer and (if saved) fine-tuned ELMo weights.

        Args:
            model_dir: directory to load the weight files from; defaults to
                ``self.path_model_dir``.
        """
        model_dir = self.path_model_dir if model_dir is None else model_dir
        self.loaded_from_file = True
        self.context_encoder.load_state_dict(torch.load(os.path.join(model_dir, "context_encoder.th"),
                                                        map_location=DEVICE))
        self.scorer.load_state_dict(torch.load(os.path.join(model_dir, "scorer.th"),
                                               map_location=DEVICE))

        # ELMo weights are only saved when they were fine-tuned
        path_to_embeddings = os.path.join(model_dir, "embeddings.th")
        if os.path.isfile(path_to_embeddings):
            logging.info(f"Loading fine-tuned ELMo weights from '{path_to_embeddings}'")
            self.embedder.load_state_dict(torch.load(path_to_embeddings, map_location=DEVICE))

    @staticmethod
    def from_pretrained(model_dir):
        """Instantiate a controller from ``model_dir`` (written by ``save_pretrained``) and load its weights."""
        controller_config_path = os.path.join(model_dir, "controller_config.json")
        with open(controller_config_path, "r", encoding="utf-8") as f_config:
            pre_config = json.load(f_config)

        instance = ContextualControllerELMo(**pre_config)
        # Weights live next to the config they were saved with
        instance.load_checkpoint(model_dir)

        return instance

    def save_pretrained(self, model_dir):
        """Write the instantiation config and all trainable weights into ``model_dir``."""
        os.makedirs(model_dir, exist_ok=True)

        # Write controller config (used for instantiation)
        controller_config_path = os.path.join(model_dir, "controller_config.json")
        with open(controller_config_path, "w", encoding="utf-8") as f_config:
            json.dump({
                "hidden_size": self.hidden_size,
                "dropout": self.dropout,
                "pretrained_embeddings_dir": self.pretrained_embeddings_dir,
                "dataset_name": self.dataset_name,
                "fc_hidden_size": self.fc_hidden_size,
                "freeze_pretrained": self.freeze_pretrained,
                "learning_rate": self.learning_rate,
                "layer_learning_rate": self.layer_learning_rate,
                "max_segment_size": self.max_segment_size,
                "max_span_size": self.max_span_size,
                "model_name": self.model_name
            }, fp=f_config, indent=4)

        torch.save(self.context_encoder.state_dict(), os.path.join(model_dir, "context_encoder.th"))
        torch.save(self.scorer.state_dict(), os.path.join(model_dir, "scorer.th"))

        # Save fine-tuned ELMo embeddings only if they're not frozen
        if not self.freeze_pretrained:
            torch.save(self.embedder.state_dict(), os.path.join(model_dir, "embeddings.th"))

    def save_checkpoint(self):
        logging.warning("save_checkpoint() is deprecated. Use save_pretrained() instead")
        self.save_pretrained(self.path_model_dir)

    def _encode_span(self, tokens, get_position):
        """Map a mention's tokens to fixed-size segment indices, in-segment positions and a mask.

        Returns ``([segment_indices, positions], attention)``. Both index lists have
        length ``max_span_size``. Longer mentions are truncated. Shorter ones are
        padded by repeating the last segment index with position -1, which points at
        the PAD word appended to every segment. ``attention`` is a bool tensor of
        shape (1, max_span_size) that is True only at real token positions.
        """
        segment_indices, positions = [], []
        for token in tokens:
            idx_segment, idx_inside_segment = get_position(token)
            segment_indices.append(idx_segment)
            positions.append(idx_inside_segment)
        num_words = len(segment_indices)

        if num_words > self.max_span_size:
            segment_indices = segment_indices[:self.max_span_size]
            positions = positions[:self.max_span_size]
        else:
            num_padding = self.max_span_size - num_words
            segment_indices += [segment_indices[-1]] * num_padding
            positions += [-1] * num_padding

        attention = torch.ones((1, self.max_span_size), dtype=torch.bool)
        attention[0, num_words:] = False
        return [segment_indices, positions], attention

    def _prepare_doc(self, curr_doc: Document) -> Dict:
        """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since
        data inside same document does not get shuffled.

        Keys: "preprocessed_segments" (character ids of all segments, each with a trailing
        PAD word) and "steps" (one dict per mention, listing all preceding mentions as candidates).
        """
        # By default, each sentence is its own segment, meaning sentences are processed independently
        if self.max_segment_size is None:
            def get_position(t):
                return t.sentence_index, t.position_in_sentence

            _encoded_segments = batch_to_ids(curr_doc.raw_sentences())
        # Optionally, one can specify max_segment_size, in which case segments of tokens are processed independently
        else:
            def get_position(t):
                doc_position = t.position_in_document
                return doc_position // self.max_segment_size, doc_position % self.max_segment_size

            flattened_doc = list(chain(*curr_doc.raw_sentences()))
            num_segments = (len(flattened_doc) + self.max_segment_size - 1) // self.max_segment_size
            _encoded_segments = batch_to_ids([
                flattened_doc[idx_seg * self.max_segment_size: (idx_seg + 1) * self.max_segment_size]
                for idx_seg in range(num_segments)
            ])

        # Convention: Add a PAD word ([0] * max_chars vector) at the end of each segment, for padding mentions
        pad_words = torch.zeros((_encoded_segments.shape[0], 1, ELMoCharacterMapper.max_word_length),
                                dtype=torch.long)
        encoded_segments = torch.cat((_encoded_segments, pad_words), dim=1)

        cluster_sets = []
        mention_to_cluster_id = {}
        for i, curr_cluster in enumerate(curr_doc.clusters):
            cluster_sets.append(set(curr_cluster))
            for mid in curr_cluster:
                mention_to_cluster_id[mid] = i

        mentions = list(curr_doc.mentions.items())
        # Span index data depends only on the mention itself, so compute it once per mention
        span_cache = {mention_id: self._encode_span(mention.tokens, get_position)
                      for mention_id, mention in mentions}

        all_candidate_data = []
        for idx_head, (head_id, _) in enumerate(mentions):
            gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]]
            curr_head_data, head_attention = span_cache[head_id]

            # Note: no data for dummy antecedent (len(`candidate_data`) is one less than `candidates`)
            candidates, candidate_data, candidate_attention = [None], [], []
            correct_antecedents = []
            # Candidates are all mentions preceding the head; index 0 is reserved for the dummy
            for idx_candidate, (cand_id, _) in enumerate(mentions[:idx_head], start=1):
                curr_candidate_data, curr_attention = span_cache[cand_id]
                candidates.append(cand_id)
                candidate_data.append(curr_candidate_data)
                candidate_attention.append(curr_attention)
                if cand_id in gt_antecedent_ids:
                    correct_antecedents.append(idx_candidate)

            # No coreferent predecessor -> the dummy antecedent is the correct answer
            if not correct_antecedents:
                correct_antecedents.append(0)

            all_candidate_data.append({
                "head_id": head_id,
                "head_data": torch.tensor([curr_head_data]),
                "head_attention": head_attention,
                "candidates": candidates,
                "candidate_data": torch.tensor(candidate_data),
                "candidate_attention": torch.cat(candidate_attention) if candidate_attention else [],
                "correct_antecedents": correct_antecedents
            })

        return {"preprocessed_segments": encoded_segments, "steps": all_candidate_data}

    def _train_doc(self, curr_doc, eval_mode=False):
        """ Trains/evaluates (if `eval_mode` is True) model on specific document.
            Returns predictions, loss and number of examples evaluated.

        Predictions map each predicted antecedent id (None = dummy) to the list of mention
        ids that chose it. No optimizer step is taken when the document produced no loss
        term (every mention only had the dummy candidate).
        """
        if len(curr_doc.mentions) == 0:
            return {}, (0.0, 0)

        if not hasattr(curr_doc, "_cache_elmo"):
            curr_doc._cache_elmo = self._prepare_doc(curr_doc)

        cache = curr_doc._cache_elmo  # type: Dict
        doc_loss, n_examples = 0.0, len(cache["steps"])
        preds = {}

        # An autograd graph is only needed when training
        with torch.set_grad_enabled(not eval_mode):
            encoded_segments = cache["preprocessed_segments"].to(DEVICE)
            if self.freeze_pretrained:
                with torch.no_grad():
                    res = self.embedder(encoded_segments)
            else:
                res = self.embedder(encoded_segments)

            # Note: max_segment_size is either specified at instantiation or (the length of longest sentence + 1)
            embedded_segments = res["elmo_representations"][0]  # [num_segments, max_segment_size, embedding_size]
            (lstm_segments, _) = self.context_encoder(embedded_segments)  # [num_segments, max_segment_size, 2 * hidden_size]

            for curr_step in cache["steps"]:
                head_id = curr_step["head_id"]
                head_data = curr_step["head_data"]

                candidates = curr_step["candidates"]
                candidate_data = curr_step["candidate_data"]
                correct_antecedents = curr_step["correct_antecedents"]

                # Note: num_candidates includes dummy antecedent + actual candidates
                num_candidates = len(candidates)
                if num_candidates == 1:
                    curr_pred = 0
                else:
                    idx_segment = candidate_data[:, 0, :]
                    idx_in_segment = candidate_data[:, 1, :]

                    # [num_candidates, max_span_size, embedding_size]
                    candidate_data = lstm_segments[idx_segment, idx_in_segment]
                    # [1, head_size, embedding_size]
                    head_data = lstm_segments[head_data[:, 0, :], head_data[:, 1, :]]
                    head_data = head_data.repeat((num_candidates - 1, 1, 1))

                    candidate_scores = self.scorer(candidate_data, head_data,
                                                   curr_step["candidate_attention"],
                                                   curr_step["head_attention"].repeat((num_candidates - 1, 1)))

                    # [1, num_candidates]; the dummy antecedent has a fixed score of 0
                    candidate_scores = torch.cat((torch.tensor([0.0], device=DEVICE),
                                                  candidate_scores.flatten())).unsqueeze(0)

                    curr_pred = torch.argmax(candidate_scores)
                    # One loss term per correct antecedent
                    doc_loss += self.loss(candidate_scores.repeat((len(correct_antecedents), 1)),
                                          torch.tensor(correct_antecedents, device=DEVICE))

                # { antecedent: [mention(s)] } pair
                existing_refs = preds.get(candidates[int(curr_pred)], [])
                existing_refs.append(head_id)
                preds[candidates[int(curr_pred)]] = existing_refs

        # doc_loss is still the float 0.0 if no step had a real candidate; nothing to backpropagate then
        if not eval_mode and torch.is_tensor(doc_loss):
            doc_loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        return preds, (float(doc_loss), n_examples)