def _unpack_model_batch_prediction(self, batch, coerce_tree=False) -> np.ndarray:
    """
    Decode the model output of one batch into relative linking distances

    Every row of an essay's "pred_matrix" entry is decoded with a plain
    argmax; the decoded value of row i is the winning column index minus i.

    Args:
        batch: keyword arguments passed on to ``self.model``
        coerce_tree: when equal to True, an essay whose argmax decoding is not
            a tree (according to ``TreeBuilder.is_tree``) is decoded again
            with ``run_MST``

    Returns:
        list holding one decoding result per essay of the batch
        NOTE(review): the signature is annotated ``np.ndarray`` although a
        plain list is built and returned
    """
    model_output = self.model(**batch)
    score_matrices = model_output["pred_matrix"]

    decoded_essays = []
    for essay_scores in score_matrices:
        # decoding using simple argmax over the last axis
        winners = np.argmax(tonp(essay_scores), axis=-1)
        # relative distance: winning column index - row index
        decoded = [winner - row for row, winner in enumerate(winners)]

        # the tree check always runs; it only has an effect with coerce_tree
        not_a_tree = not TreeBuilder(decoded).is_tree()
        if not_a_tree and (coerce_tree == True):
            # run MINIMUM spanning tree: the rank is used as the weight and a
            # lower rank number is better
            weights = np.array(tonp(essay_scores))
            rank_order = get_rank_order(weights)
            decoded = run_MST(rank_order, rank_order, verdict="min")

        # add the decoding result to the batch result
        decoded_essays.append(decoded)

    return decoded_essays
def _unpack_gold_batch_prediction(self, batch: np.ndarray) -> List:
    """
    Collect the gold linking distances and component labels of a batch

    For essay ``b`` only the first ``seq_len[b]`` entries are kept, so that
    padded positions are left out. The kept indices are converted with
    ``self.dist_idx_to_dist`` and ``self.component_idx_to_label``.

    Args:
        batch: object indexed with the keys "rel_dists", "component_labels"
            and "seq_len"
            NOTE(review): annotated as np.ndarray, but the body indexes it
            with string keys -- confirm the intended type

    Returns:
        (output_linking, output_node_labelling): two lists with one list per
        essay, holding the converted distances and labels respectively
        NOTE(review): annotated as List, but a 2-tuple is returned
    """
    gold_dist_ids = tonp(batch["rel_dists"])
    gold_label_ids = tonp(batch["component_labels"])
    lengths = batch["seq_len"]

    output_linking, output_node_labelling = [], []
    for essay_no, essay_dist_ids in enumerate(gold_dist_ids):
        n_kept = lengths[essay_no]

        # gold linking: drop the padding, then map index -> distance
        kept_dist_ids = essay_dist_ids[:n_kept].tolist()
        essay_dists = list(map(self.dist_idx_to_dist, kept_dist_ids))

        # gold node labels: drop the padding, then map index -> label
        kept_label_ids = gold_label_ids[essay_no][:n_kept].tolist()
        essay_labels = list(map(self.component_idx_to_label, kept_label_ids))

        output_linking.append(essay_dists)
        output_node_labelling.append(essay_labels)

    return output_linking, output_node_labelling
def predict(self, ds: Iterable[Instance], coerce_tree=False) -> np.ndarray:
    """
    Run the model over a dataset and gather predictions and gold answers

    The dataset is traversed once, unshuffled, with the model in eval mode
    and gradient tracking disabled. Each batch is moved to
    ``self.cuda_device`` before it is decoded.

    Args:
        ds: instances to predict on
        coerce_tree: forwarded to ``_unpack_model_batch_prediction``; True if
            we want to make sure that the prediction forms a tree

    Returns:
        (preds, golds): two lists extended batch by batch
        NOTE(review): annotated as np.ndarray, but a 2-tuple is returned
    """
    batches = self.iterator(ds, num_epochs=1, shuffle=False)
    self.model.eval()
    n_batches = self.iterator.get_num_batches(ds)
    progress = tqdm(batches, total=n_batches)

    # open question kept from the original author: what if the valid/test
    # data contain a label that does not exist in the training data
    # --> workaround for the vocab
    preds, golds = [], []
    with torch.no_grad():
        for raw_batch in progress:
            device_batch = nn_util.move_to_device(raw_batch, self.cuda_device)

            batch_preds = self._unpack_model_batch_prediction(
                device_batch, coerce_tree=coerce_tree
            )
            preds.extend(batch_preds)

            # NOTE(review): the gold unpacker is called with two positional
            # arguments (the "rel_dists" array and "seq_len") -- confirm this
            # matches the signature defined on this class
            batch_golds = self._unpack_gold_batch_prediction(
                tonp(device_batch["rel_dists"]), device_batch["seq_len"]
            )
            golds.extend(batch_golds)

    return preds, golds
def _unpack_model_batch_prediction(self, batch, coerce_tree=False) -> tuple:
    """
    Interpret prediction result per batch (linking and node labelling)

    Linking is decoded per sentence with a constrained argmax: candidates are
    tried from the highest softmax value downwards and the first distance
    whose target (distance + sentence index) lies inside the essay is kept.
    Node labels are decoded with a plain argmax.

    Args:
        batch: keyword arguments for ``self.model``; ``batch["seq_len"][es]``
            is used as the number of sentences of essay ``es``
        coerce_tree (bool): True if we want to make sure that the linking
            predictions form a tree; an essay whose decoding is not a tree is
            decoded again with the MAXIMUM spanning tree algorithm

    Returns:
        tuple: (linking_preds, node_labelling_preds), each a list with one
        entry per essay of the batch

    Raises:
        ValueError: if no candidate distance keeps the target of a sentence
            inside its essay (previously this either raised
            UnboundLocalError or silently reused the distance decoded for
            the preceding sentence)
    """
    out_dict = self.model(**batch)
    pred_linking_softmax = tonp(out_dict["pred_linking_softmax"])
    pred_node_labelling_softmax = tonp(
        out_dict["pred_node_labelling_softmax"])

    linking_preds = []
    node_labelling_preds = []
    for es, essay_link_softmax in enumerate(pred_linking_softmax):
        essay_linking = []
        essay_labelling = []
        max_seq_len = batch["seq_len"][es]

        # s is the index of the current sentence in the essay
        for s in range(max_seq_len):
            # constrained argmax for linking; reversed(sorted(...)) is kept
            # (instead of reverse=True) so that ties are still resolved in
            # favour of the higher index
            ascending = sorted(enumerate(essay_link_softmax[s]),
                               key=lambda pair: pair[1])
            for dist_idx, _ in reversed(ascending):
                tmp_dist = self.dist_idx_to_dist(dist_idx)
                if 0 <= tmp_dist + s < max_seq_len:
                    pred_dist = tmp_dist
                    break
            else:
                # never fall through with an unbound or stale pred_dist
                raise ValueError(
                    f"no linking distance keeps sentence {s} of essay {es} "
                    f"inside the essay (seq_len={max_seq_len})"
                )

            # argmax for labelling
            pred_idx = np.argmax(pred_node_labelling_softmax[es][s])
            pred_label = self.component_idx_to_label(pred_idx)

            # essay-level result
            essay_linking.append(pred_dist)
            essay_labelling.append(pred_label)

        # check if the output is a tree (the check itself always runs)
        is_tree = TreeBuilder(essay_linking).is_tree()
        if coerce_tree and not is_tree:
            # element [i,j] denotes the probability of sentence i connects to
            # sentence j (j as the target)
            attn_matrix = []
            for s in range(max_seq_len):
                # prediction to each possible target sentence in the text;
                # targets no distance can reach keep the value 0
                row_pred = [0] * max_seq_len
                for dist_idx, value in enumerate(essay_link_softmax[s]):
                    target = self.dist_idx_to_dist(dist_idx) + s
                    if 0 <= target < max_seq_len:
                        row_pred[target] = value
                attn_matrix.append(row_pred)

            # run MAXIMUM spanning tree --> the softmax probability is the
            # weight, and a higher probability means better
            attn_matrix = np.array(attn_matrix)
            rank_order = get_rank_order(attn_matrix)
            essay_linking = run_MST(rank_order, attn_matrix, verdict="max")

        # batch-level result
        linking_preds.append(essay_linking)
        node_labelling_preds.append(essay_labelling)

    return linking_preds, node_labelling_preds
def _unpack_model_batch_prediction(self, batch, coerce_tree=False) -> list:
    """
    Interpret prediction result per batch

    Each sentence is decoded with a constrained argmax: candidates are tried
    from the highest softmax value downwards and the first distance whose
    target (distance + sentence index) lies inside the essay is kept.

    Args:
        batch: keyword arguments for ``self.model``; ``batch["seq_len"][es]``
            is used as the number of sentences of essay ``es``
        coerce_tree (bool): True if we want to make sure that the predictions
            form a tree; an essay whose decoding is not a tree is decoded
            again with the MAXIMUM spanning tree algorithm

    Returns:
        list: one decoding result per essay of the batch

    Raises:
        ValueError: if no candidate distance keeps the target of a sentence
            inside its essay (previously this either raised
            UnboundLocalError or silently reused the distance decoded for
            the preceding sentence)
    """
    out_dict = self.model(**batch)
    pred_softmax = tonp(out_dict["pred_softmax"])

    batch_interpretation = []
    for es, essay_softmax in enumerate(pred_softmax):
        essay_interpretation = []
        max_seq_len = batch["seq_len"][es]

        # s is the index of the current sentence in the essay
        for s in range(max_seq_len):
            # constrained argmax; reversed(sorted(...)) is kept (instead of
            # reverse=True) so that ties are still resolved in favour of the
            # higher index
            ascending = sorted(enumerate(essay_softmax[s]),
                               key=lambda pair: pair[1])
            for dist_idx, _ in reversed(ascending):
                tmp_dist = self.dist_idx_to_dist(dist_idx)
                if 0 <= tmp_dist + s < max_seq_len:
                    pred_dist = tmp_dist
                    break
            else:
                # never fall through with an unbound or stale pred_dist
                raise ValueError(
                    f"no linking distance keeps sentence {s} of essay {es} "
                    f"inside the essay (seq_len={max_seq_len})"
                )
            essay_interpretation.append(pred_dist)

        # check if the output is a tree (the check itself always runs)
        is_tree = TreeBuilder(essay_interpretation).is_tree()
        if coerce_tree and not is_tree:
            # element [i,j] denotes the probability of sentence i connects to
            # sentence j (j as the target)
            attn_matrix = []
            for s in range(max_seq_len):
                # prediction to each possible target sentence in the text;
                # targets no distance can reach keep the value 0
                row_pred = [0] * max_seq_len
                for dist_idx, value in enumerate(essay_softmax[s]):
                    target = self.dist_idx_to_dist(dist_idx) + s
                    if 0 <= target < max_seq_len:
                        row_pred[target] = value
                attn_matrix.append(row_pred)

            # run MAXIMUM spanning tree --> the softmax probability is the
            # weight, and a higher probability means better
            attn_matrix = np.array(attn_matrix)
            rank_order = get_rank_order(attn_matrix)
            essay_interpretation = run_MST(
                rank_order, attn_matrix, verdict="max")

        batch_interpretation.append(essay_interpretation)

    return batch_interpretation