i2supertag = [model.vocab.get_token_from_index(i, namespace=formalism+"_supertag_labels") for i in range(model.vocab.get_vocab_size(formalism+"_supertag_labels"))] lexlabel2i = { model.vocab.get_token_from_index(i, namespace=formalism+"_lex_labels") : i for i in range(model.vocab.get_vocab_size(formalism+"_lex_labels"))} def dump_tags(score, fragment, type): if type == "_": #\bot x = "NULL" else: x = fragment.replace(" "," ").replace(" ","__ALTO_WS__")+"--TYPE--"+str(type).replace(" ","") return x+"|"+str(round(score,5)) top_k_labels = 30 top_k_supertags = 15 bot_id = model.vocab.get_token_index(AMSentence.get_bottom_supertag(),namespace=formalism+"_supertag_labels") with zipfile.ZipFile(args.output_path + "/scores.zip","w",compression=zipfile.ZIP_DEFLATED, compresslevel=7) as myzip: tagprobs = [] modified_conll_sentences = [] with myzip.open("opProbs.txt","w") as fp: for sentence_id,pred in enumerate(predictions): attributes = pred["attributes"] all_supertag_scores = F.log_softmax(torch.from_numpy(pred["supertag_scores"]),1) #shape (sent length, num supertags) top_k_supertag_indices = torch.argsort(all_supertag_scores, descending=True, dim=1)[:, :top_k_supertags].numpy() all_supertag_scores = all_supertag_scores.numpy() edge_scores = np.transpose(pred["edge_existence_scores"],[1,0]) #shape (sent len+1 (from), sent len+1 (to)) mask = 1e9*np.eye(edge_scores.shape[0])
def prepare_for_ftd(self, output_dict: Dict[str, torch.Tensor]):
    """
    Prepare the result of ``forward`` for fixed-tree decoding.

    This does not perform the decoding itself, it only prepares it. For each
    sentence in the batch it:
      - removes padding,
      - identifies the root of the sentence via ``find_root`` (called with
        ``modify=True``, so the heads may change), and re-computes the edge
        label logits for the possibly changed heads,
      - collects a pre-selection of supertags to speed up computation (the
        final top-k selection is done later).

    :param output_dict: result of ``forward``. The keys ``best_supertags``,
        ``supertag_scores``, ``full_label_logits``, ``edge_existence_scores``,
        ``lexlabels``, ``heads``, ``mask``, ``label_logits``,
        ``encoded_text_parsing`` and ``encoded_text_tagging`` are popped
        (consumed). The popped ``heads`` tensor is modified in place with the
        heads returned by ``find_root``.
    :return: the same ``output_dict`` with the following per-sentence lists
        added. Unless stated otherwise, position 0 (the artificial root) has
        been removed:
        - lexlabels: for each word, the most likely lexical label (string)
        - supertags: for each word, a list of ``(score, fragment, type)``
          triples for the pre-selected supertags. The bottom supertag is
          always included.
        - root: the root returned by ``find_root``
        - label_logits: edge label logits for the (possibly updated) heads,
          shape (sent len, num edge labels)
        - predicted_heads: the (possibly updated) head of each word
        - full_label_logits: shape (sent len + 1, sent len + 1,
          num edge labels); artificial root kept
        - edge_existence_scores: shape (sent len + 1, sent len + 1);
          artificial root kept
        - supertag_scores: shape (sent len, num supertags)
        - normalized_prepare_ftd_time: time spent in this method divided by
          the batch size, repeated once per sentence
    """
    t0 = time()
    supertag_namespace = self.name + "_supertag_labels"
    lex_namespace = self.name + "_lex_labels"

    best_supertags = output_dict.pop("best_supertags").cpu().detach().numpy()
    # shape (batch_size, seq_len, num supertags)
    supertag_scores = output_dict.pop("supertag_scores")
    # shape (batch size, seq len, seq len, num edge labels)
    full_label_logits = output_dict.pop("full_label_logits").cpu().detach().numpy()
    # shape (batch size, seq len, seq len) -- indexed below as [i, :length, :length]
    edge_existence_scores = output_dict.pop("edge_existence_scores").cpu().detach().numpy()

    k = 10
    if self.validation_evaluator:
        # Retrieve the number of supertags k from the validation evaluator.
        if isinstance(self.validation_evaluator.predictor, AMconllPredictor):
            k = self.validation_evaluator.predictor.k
    # Some supertags may be ill-formed. Taking 10 extra makes it very unlikely
    # that too few remain after filtering.
    k += 10

    # shape (batch_size, seq_len, k)
    top_k_supertags = Supertagger.top_k_supertags(supertag_scores, k).cpu().detach().numpy()
    supertag_scores = supertag_scores.cpu().detach().numpy()
    # shape (batch_size, seq_len)
    lexlabels = output_dict.pop("lexlabels").cpu().detach().numpy()
    heads = output_dict.pop("heads")
    heads_cpu = heads.cpu().detach().numpy()
    mask = output_dict.pop("mask")
    # shape (batch_size, seq_len, num edge labels)
    edge_label_logits = output_dict.pop("label_logits").cpu().detach().numpy()
    encoded_text_parsing = output_dict.pop("encoded_text_parsing")
    output_dict.pop("encoded_text_tagging")  # not needed here

    # Convert once to plain Python ints. They are used for slicing and range()
    # below, which would otherwise go through tensor __index__ every time.
    lengths = [int(length) for length in get_lengths_from_binary_sequence_mask(mask)]

    # One entry per sentence is collected in each of these lists.
    all_edge_label_logits = []
    all_supertags = []
    head_indices = []
    roots = []
    all_predicted_lex_labels = []
    all_full_label_logits = []
    all_edge_existence_scores = []
    all_supertag_scores = []

    # Needed to identify the root.
    root_edge_label_id = self.vocab.get_token_index("ROOT", namespace=self.name + "_head_tags")
    bottom_supertag = AMSentence.get_bottom_supertag()
    bot_id = self.vocab.get_token_index(bottom_supertag, namespace=supertag_namespace)
    # Loop-invariant: the split of the bottom supertag is the same for every word.
    bot_fragment, bot_type = AMSentence.split_supertag(bottom_supertag)

    for i, length in enumerate(lengths):
        # Post-process the heads and find the root of the sentence.
        # Position 0 (the artificial root) is dropped here.
        instance_heads_cpu, root = find_root(
            list(heads_cpu[i, 1:length]),
            best_supertags[i, 1:length],
            edge_label_logits[i, 1:length, :],
            root_edge_label_id,
            bot_id,
            modify=True)
        roots.append(root)

        # Write the (possibly changed) heads back into the heads tensor in one
        # assignment. The offset of 1 is because position 0 was removed above.
        instance_heads = heads[i, :]
        instance_heads[1:1 + len(instance_heads_cpu)] = torch.as_tensor(
            instance_heads_cpu, dtype=instance_heads.dtype, device=instance_heads.device)

        # Re-calculate the edge label logits, since the heads might have changed.
        # (un)squeeze: fake batch dimension.
        label_logits = self.edge_model.label_scores(
            encoded_text_parsing[i].unsqueeze(0),
            instance_heads.unsqueeze(0)).squeeze(0).detach().cpu().numpy()
        all_edge_label_logits.append(label_logits[1:length, :])
        all_full_label_logits.append(full_label_logits[i, :length, :length, :])
        all_edge_existence_scores.append(edge_existence_scores[i, :length, :length])
        # new shape (sent length, num supertags)
        all_supertag_scores.append(supertag_scores[i, 1:length, :])

        # Collect the pre-selected supertags for this sentence.
        supertags_for_this_sentence = []
        for word in range(1, length):
            word_top_k = top_k_supertags[i, word]
            supertags_for_this_word = []
            for supertag_id in word_top_k:
                fragment, typ = AMSentence.split_supertag(
                    self.vocab.get_token_from_index(supertag_id, namespace=supertag_namespace))
                supertags_for_this_word.append((supertag_scores[i, word, supertag_id], fragment, typ))
            if bot_id not in word_top_k:
                # \bot is not among the top k, but it has to be added anyway
                # for the decoder to work properly.
                supertags_for_this_word.append((supertag_scores[i, word, bot_id], bot_fragment, bot_type))
            supertags_for_this_sentence.append(supertags_for_this_word)
        all_supertags.append(supertags_for_this_sentence)

        all_predicted_lex_labels.append([
            self.vocab.get_token_from_index(label, namespace=lex_namespace)
            for label in lexlabels[i, 1:length]
        ])
        head_indices.append(instance_heads_cpu)

    t1 = time()
    batch_size = len(lengths)
    # Guard against an empty batch, which would otherwise divide by zero.
    normalized_diff = (t1 - t0) / max(batch_size, 1)
    output_dict["normalized_prepare_ftd_time"] = [normalized_diff] * batch_size
    output_dict["lexlabels"] = all_predicted_lex_labels
    output_dict["supertags"] = all_supertags
    output_dict["root"] = roots
    output_dict["label_logits"] = all_edge_label_logits
    output_dict["predicted_heads"] = head_indices
    output_dict["full_label_logits"] = all_full_label_logits
    output_dict["edge_existence_scores"] = all_edge_existence_scores
    output_dict["supertag_scores"] = all_supertag_scores
    return output_dict
for i in range(model.vocab.get_vocab_size(formalism + "_lex_labels")) } def dump_tags(score, fragment, type): if type == "_": #\bot x = "NULL" else: x = fragment.replace(" ", " ").replace( " ", "__ALTO_WS__") + "--TYPE--" + str(type).replace(" ", "") return x + "|" + str(round(score, 5)) top_k_labels = 30 top_k_supertags = 15 bot_id = model.vocab.get_token_index(AMSentence.get_bottom_supertag(), namespace=formalism + "_supertag_labels") with zipfile.ZipFile(args.output_path + "/scores.zip", "w", compression=zipfile.ZIP_DEFLATED, compresslevel=7) as myzip: tagprobs = [] modified_conll_sentences = [] with myzip.open("opProbs.txt", "w") as fp: for sentence_id, pred in enumerate(predictions): attributes = pred["attributes"] all_supertag_scores = F.log_softmax( torch.from_numpy(pred["supertag_scores"]), 1) #shape (sent length, num supertags)