def create_frame_prediction_pairs(self):
    """Build one ``FramePredictionPair`` per (predicted, target) tree pair.

    Trees are taken position-by-position from ``self.all_pred_trees`` and
    ``self.all_target_trees``. Each tree is converted with
    ``CompositionalMetricReporter.tree_to_metric_node``. Because pairing
    uses ``zip``, extra trees in the longer list are ignored.

    Returns:
        list of ``FramePredictionPair(predicted_node, target_node)``.
    """
    to_metric_node = CompositionalMetricReporter.tree_to_metric_node
    paired_trees = zip(self.all_pred_trees, self.all_target_trees)
    return [
        FramePredictionPair(to_metric_node(predicted), to_metric_node(expected))
        for predicted, expected in paired_trees
    ]
Example #2
0
    def aggregate_preds(self, new_batch, context=None):
        """Accumulate one batch of beam-search predictions into the reporter.

        Every hypothesis in ``new_batch[0]`` is passed through
        ``self._remove_tokens`` with the target vocab's pad/EOS/BOS indices.
        The method then appends to:

        * ``all_preds``: the cleaned token sequences;
        * ``all_pred_trees``: the annotation tree of the first hypothesis
          of each beam;
        * ``all_target_length_preds``: ``new_batch[1]`` as given;
        * ``all_beam_preds``: every hypothesis converted to a metric node;
        * ``all_top_non_invalid`` / ``all_top_extract``: the annotations
          built from the strings chosen by ``get_top_non_invalid`` /
          ``get_top_extract``.

        Args:
            new_batch: indexable batch whose element 0 holds the tree
                predictions and element 1 the length predictions. If it is
                ``None``, nothing is recorded.
            context: not used by this method; kept for interface
                compatibility.
        """
        if new_batch is None:
            return
        tree_preds = new_batch[0]  # bsz X beam_size X seq_len
        length_preds = new_batch[1]
        target_vocab = self.tensorizers["trg_seq_tokens"].vocab
        target_pad_token = target_vocab.get_pad_index()
        target_bos_token = target_vocab.get_bos_index()
        target_eos_token = target_vocab.get_eos_index()

        cleaned_preds = [
            self._remove_tokens(
                pred, [target_pad_token, target_eos_token, target_bos_token])
            for pred in self._make_simple_list(tree_preds)
        ]
        self.aggregate_data(self.all_preds, cleaned_preds)

        # Build each hypothesis' annotation tree exactly once. The first tree
        # of each beam is reused for all_pred_trees rather than being
        # stringified a second time.
        beam_trees = [
            [self.stringify_annotation_tree(pred, target_vocab) for pred in beam]
            for beam in cleaned_preds
        ]
        pred_trees = [trees[0] for trees in beam_trees]
        beam_pred_trees = [
            [CompositionalMetricReporter.tree_to_metric_node(tree)
             for tree in trees]
            for trees in beam_trees
        ]

        # Stringify each beam once and share the result between both
        # selectors. Previously it was computed twice per beam. Each selector
        # gets its own list copy in case either one mutates its argument.
        beam_strings = [
            [stringify(pred, target_vocab) for pred in beam]
            for beam in cleaned_preds
        ]
        top_non_invalid_trees = [
            self.get_annotation_from_string(
                self.get_top_non_invalid(list(strings)))
            for strings in beam_strings
        ]
        top_extracted_trees = [
            self.get_annotation_from_string(
                self.get_top_extract(list(strings)))
            for strings in beam_strings
        ]

        self.aggregate_data(self.all_pred_trees, pred_trees)
        self.aggregate_data(self.all_target_length_preds, length_preds)
        self.aggregate_data(self.all_beam_preds, beam_pred_trees)
        self.aggregate_data(self.all_top_non_invalid, top_non_invalid_trees)
        self.aggregate_data(self.all_top_extract, top_extracted_trees)
 def test_tree_to_metric_node(self):
     """Check ``tree_to_metric_node`` on three annotated parses.

     The expected spans are character offsets into the utterance that
     remains once the bracketed labels are removed. For example, in
     "call moms cellphone" the slot "moms" covers 5-9.
     """
     alarm_parse = (
         "[IN:alarm/set_alarm  repeat the [SL:datetime 3 : 00 pm ] "
         + "[SL:alarm/name alarm ]  [SL:datetime for Sunday august 12th ]  ] "
     )
     alarm_frame = Node(
         label="IN:alarm/set_alarm",
         span=Span(start=0, end=49),
         children={
             Node(label="SL:datetime", span=Span(start=11, end=20)),
             Node(label="SL:alarm/name", span=Span(start=21, end=26)),
             Node(label="SL:datetime", span=Span(start=27, end=49)),
         },
     )

     call_parse = "[IN:calling/call_friend call [SL:person moms ] cellphone ]"
     call_frame = Node(
         label="IN:calling/call_friend",
         span=Span(start=0, end=19),
         children={Node(label="SL:person", span=Span(start=5, end=9))},
     )

     # Nested case: an intent embedded inside a slot.
     directions_parse = (
         "[IN:GET_DIRECTIONS I need [SL:ANCHOR directions] to [SL:DESTINATION "
         + "[IN:GET_EVENT the jazz festival]]]"
     )
     event_node = Node(label="IN:GET_EVENT", span=Span(start=21, end=38))
     destination_node = Node(
         label="SL:DESTINATION",
         span=Span(start=21, end=38),
         children={event_node},
     )
     directions_frame = Node(
         label="IN:GET_DIRECTIONS",
         span=Span(start=0, end=38),
         children={
             Node(label="SL:ANCHOR", span=Span(start=7, end=17)),
             destination_node,
         },
     )

     cases = (
         (alarm_parse, alarm_frame),
         (call_parse, call_frame),
         (directions_parse, directions_frame),
     )
     for parse_text, wanted_frame in cases:
         parsed_tree = Annotation(parse_text).tree
         actual_frame = CompositionalMetricReporter.tree_to_metric_node(parsed_tree)
         self.assertEqual(actual_frame, wanted_frame)
Example #4
0
def get_frame(parse: str) -> Node:
    """Convert a bracketed annotation string into a metric ``Node``.

    The string is parsed with ``Annotation``, and the resulting tree is
    converted by ``CompositionalMetricReporter.tree_to_metric_node``.
    """
    parsed_tree = Annotation(parse).tree
    return CompositionalMetricReporter.tree_to_metric_node(parsed_tree)