Ejemplo n.º 1
0
    def _unpack_model_batch_prediction(self,
                                       batch,
                                       coerce_tree=False) -> list:
        """
        Interpret the model's prediction result for one batch.

        Each essay's prediction matrix (``out_dict["pred_matrix"][es]``) is
        decoded with a plain argmax over its last axis; the argmax column of
        row ``i`` is turned into the relative distance ``argmax - i``.

        Args:
            batch: model input batch, forwarded to ``self.model`` as keyword
                arguments.
            coerce_tree: if True and the decoded distances do not form a tree
                (according to ``TreeBuilder.is_tree``), the essay is re-decoded
                with a MINIMUM spanning tree over the rank order of the
                prediction matrix.

        Returns:
            A list with one entry per essay, each a per-row list of distances.
        """
        out_dict = self.model(**batch)
        pred_matrix = out_dict["pred_matrix"]

        batch_interpretation = []
        for essay_matrix in pred_matrix:
            # convert once; reused by the argmax decoding and the MST fallback
            attn_matrix = np.array(tonp(essay_matrix))

            # decoding using simple argmax: target column relative to the row
            argmax_targets = np.argmax(attn_matrix, axis=-1)
            dist_interpretation = [
                target - i for i, target in enumerate(argmax_targets)
            ]

            # the tree check is only needed when we are asked to coerce
            if coerce_tree and not TreeBuilder(dist_interpretation).is_tree():
                # run MINIMUM spanning tree: the rank is used as the weight,
                # and a lower rank number is better
                rank_order = get_rank_order(attn_matrix)
                dist_interpretation = run_MST(
                    rank_order, rank_order, verdict="min")

            # add the decoding result to the batch result
            batch_interpretation.append(dist_interpretation)
        return batch_interpretation
Ejemplo n.º 2
0
    def _unpack_gold_batch_prediction(self, batch: dict) -> tuple:
        """
        Extract the gold annotations of a batch, dropping padding.

        Args:
            batch: string-keyed batch mapping holding ``"rel_dists"`` and
                ``"component_labels"`` (per-essay index sequences, converted
                with ``tonp``) and ``"seq_len"`` (number of real entries of
                each essay; everything after it is padding).

        Returns:
            (output_linking, output_node_labelling): two lists with one entry
            per essay. Each entry holds the first ``seq_len`` gold values,
            mapped through ``self.dist_idx_to_dist`` (linking) or
            ``self.component_idx_to_label`` (node labelling).
        """
        batch_rel_dists = tonp(batch["rel_dists"])
        batch_component_labels = tonp(batch["component_labels"])
        seq_len = batch["seq_len"]

        output_linking = []
        output_node_labelling = []
        for b, essay_rel_dists in enumerate(batch_rel_dists):
            # keep only the non-padded part of the essay
            n_real = int(seq_len[b])
            output_linking.append([
                self.dist_idx_to_dist(x)
                for x in essay_rel_dists[:n_real].tolist()
            ])
            output_node_labelling.append([
                self.component_idx_to_label(x)
                for x in batch_component_labels[b][:n_real].tolist()
            ])
        return output_linking, output_node_labelling
Ejemplo n.º 3
0
    def predict(self, ds: Iterable[Instance], coerce_tree=False) -> tuple:
        """
        Generate prediction results for a dataset.

        Makes one unshuffled pass of ``self.iterator`` over ``ds`` (with a
        tqdm progress bar), moves each batch to ``self.cuda_device`` and, with
        gradient tracking disabled, collects the decoded model predictions and
        the gold annotations.

        Args:
            ds: instances to predict on.
            coerce_tree: if True, ask the batch decoder to make sure every
                prediction forms a tree.

        Returns:
            (preds, golds): lists built by extending with each batch's decoded
            predictions / gold annotations, in iterator order.

        Note:
            ``self.model`` is switched to eval mode and left there.
        """
        pred_generator = self.iterator(ds, num_epochs=1, shuffle=False)
        self.model.eval()
        # NOTE(review): what if the valid/test data contain a label that does
        # not exist in the training data? --> workaround needed for the vocab
        pred_generator_tqdm = tqdm(
            pred_generator, total=self.iterator.get_num_batches(ds))

        preds = []
        golds = []
        with torch.no_grad():
            for batch in pred_generator_tqdm:
                batch = nn_util.move_to_device(batch, self.cuda_device)
                preds.extend(
                    self._unpack_model_batch_prediction(
                        batch, coerce_tree=coerce_tree))
                golds.extend(
                    self._unpack_gold_batch_prediction(
                        tonp(batch["rel_dists"]), batch["seq_len"]))

        return preds, golds
Ejemplo n.º 4
0
    def _unpack_model_batch_prediction(self,
                                       batch,
                                       coerce_tree=False) -> tuple:
        """
        Interpret the model's prediction result for one batch.

        For every essay, each of its ``seq_len`` sentences gets
          * a link distance: the highest-probability distance class whose
            target sentence ``s + distance`` lies inside the essay
            (constrained argmax over ``pred_linking_softmax``), and
          * a component label: plain argmax over
            ``pred_node_labelling_softmax``.

        Args:
            batch: model input batch; forwarded to ``self.model`` as keyword
                arguments and also read for ``batch["seq_len"]``, the number of
                sentences of each essay.
            coerce_tree: if True and the decoded links do not form a tree
                (according to ``TreeBuilder.is_tree``), the links are re-decoded
                with a MAXIMUM spanning tree whose weights are the linking
                probabilities.

        Returns:
            (linking_preds, node_labelling_preds): two lists with one entry per
            essay, each a per-sentence list of link distances / labels.

        Raises:
            ValueError: if, for some sentence, no distance class points to a
                sentence inside the essay. Without this check the previous
                sentence's (possibly out-of-range) distance would be reused.
        """
        out_dict = self.model(**batch)
        pred_linking_softmax = tonp(out_dict["pred_linking_softmax"])
        pred_node_labelling_softmax = tonp(
            out_dict["pred_node_labelling_softmax"])

        def constrained_argmax_dist(link_scores, s, seq_len):
            """Return the best-scoring distance whose target ``s + dist`` is in ``[0, seq_len)``."""
            # reversed(sorted(...)) rather than sorted(reverse=True): among
            # equal scores the larger class index is tried first
            ranked = reversed(sorted(enumerate(link_scores), key=lambda x: x[1]))
            for class_idx, _ in ranked:
                dist = self.dist_idx_to_dist(class_idx)
                if 0 <= dist + s <= seq_len - 1:
                    return dist
            raise ValueError(
                f"no valid link distance for sentence {s} of an essay with "
                f"{seq_len} sentences")

        def link_probability_matrix(essay_link_softmax, seq_len):
            """Build a seq_len x seq_len matrix; [i, j] = P(sentence i links to sentence j)."""
            matrix = []
            for s in range(seq_len):
                row = [0] * seq_len
                for class_idx, prob in enumerate(essay_link_softmax[s]):
                    target = self.dist_idx_to_dist(class_idx) + s
                    if 0 <= target <= seq_len - 1:
                        row[target] = prob
                matrix.append(row)
            return np.array(matrix)

        linking_preds = []
        node_labelling_preds = []
        for es, essay_link_softmax in enumerate(pred_linking_softmax):
            essay_label_softmax = pred_node_labelling_softmax[es]
            # int() so range(), list repetition and arithmetic behave the same
            # whatever numeric type batch["seq_len"] holds
            seq_len = int(batch["seq_len"][es])

            # simple decoding: constrained argmax for linking, argmax for labels
            essay_linking = [
                constrained_argmax_dist(essay_link_softmax[s], s, seq_len)
                for s in range(seq_len)
            ]
            essay_labelling = [
                self.component_idx_to_label(np.argmax(essay_label_softmax[s]))
                for s in range(seq_len)
            ]

            # the tree check is only needed when we are asked to coerce
            if coerce_tree and not TreeBuilder(essay_linking).is_tree():
                # run MAXIMUM spanning tree: the softmax probability is the
                # weight, and a higher probability is better
                attn_matrix = link_probability_matrix(essay_link_softmax, seq_len)
                rank_order = get_rank_order(attn_matrix)
                essay_linking = run_MST(rank_order, attn_matrix, verdict="max")

            # batch-level result
            linking_preds.append(essay_linking)
            node_labelling_preds.append(essay_labelling)

        return linking_preds, node_labelling_preds
Ejemplo n.º 5
0
    def _unpack_model_batch_prediction(self,
                                       batch,
                                       coerce_tree=False) -> list:
        """
        Interpret the model's prediction result for one batch.

        For every essay, each of its ``seq_len`` sentences gets the
        highest-probability distance class (from ``out_dict["pred_softmax"]``)
        whose target sentence ``s + distance`` lies inside the essay
        (constrained argmax).

        Args:
            batch: model input batch; forwarded to ``self.model`` as keyword
                arguments and also read for ``batch["seq_len"]``, the number of
                sentences of each essay.
            coerce_tree: if True and the decoded links do not form a tree
                (according to ``TreeBuilder.is_tree``), the links are re-decoded
                with a MAXIMUM spanning tree whose weights are the predicted
                probabilities.

        Returns:
            A list with one entry per essay, each a per-sentence list of link
            distances.

        Raises:
            ValueError: if, for some sentence, no distance class points to a
                sentence inside the essay. Without this check the previous
                sentence's (possibly out-of-range) distance would be reused.
        """
        out_dict = self.model(**batch)
        pred_softmax = tonp(out_dict["pred_softmax"])

        def constrained_argmax_dist(link_scores, s, seq_len):
            """Return the best-scoring distance whose target ``s + dist`` is in ``[0, seq_len)``."""
            # reversed(sorted(...)) rather than sorted(reverse=True): among
            # equal scores the larger class index is tried first
            ranked = reversed(sorted(enumerate(link_scores), key=lambda x: x[1]))
            for class_idx, _ in ranked:
                dist = self.dist_idx_to_dist(class_idx)
                if 0 <= dist + s <= seq_len - 1:
                    return dist
            raise ValueError(
                f"no valid link distance for sentence {s} of an essay with "
                f"{seq_len} sentences")

        def link_probability_matrix(essay_softmax, seq_len):
            """Build a seq_len x seq_len matrix; [i, j] = P(sentence i links to sentence j)."""
            matrix = []
            for s in range(seq_len):
                row = [0] * seq_len
                for class_idx, prob in enumerate(essay_softmax[s]):
                    target = self.dist_idx_to_dist(class_idx) + s
                    if 0 <= target <= seq_len - 1:
                        row[target] = prob
                matrix.append(row)
            return np.array(matrix)

        batch_interpretation = []
        for es, essay_softmax in enumerate(pred_softmax):
            # int() so range(), list repetition and arithmetic behave the same
            # whatever numeric type batch["seq_len"] holds
            seq_len = int(batch["seq_len"][es])

            # simple decoding using constrained argmax
            essay_interpretation = [
                constrained_argmax_dist(essay_softmax[s], s, seq_len)
                for s in range(seq_len)
            ]

            # the tree check is only needed when we are asked to coerce
            if coerce_tree and not TreeBuilder(essay_interpretation).is_tree():
                # run MAXIMUM spanning tree: the softmax probability is the
                # weight, and a higher probability is better
                attn_matrix = link_probability_matrix(essay_softmax, seq_len)
                rank_order = get_rank_order(attn_matrix)
                essay_interpretation = run_MST(
                    rank_order, attn_matrix, verdict="max")

            batch_interpretation.append(essay_interpretation)

        return batch_interpretation