Пример #1
0
    def dump_line(self, outputs: JsonDict) -> tuple:
        """
        Turn one prediction into predicted graphs plus a CoNLL-U string.

        This hijacks the parent class's ``dump_line`` hook. Instead of
        returning one printable line, it returns a 3-tuple.

        :param outputs: model outputs for a single instance. They are passed
            to ``DecompGraphWithSyntax.from_prediction``. When a CoNLL-U graph
            comes back, ``outputs['syn_nodes']`` is used for the text line.
            For the ``concat-before`` / ``concat-after`` syntactic methods,
            ``outputs['syn_nodes']`` is overwritten in place with the token
            forms of the CoNLL-U graph.
        :returns: ``(pred_sem_graph, pred_syn_graph, conllu_str)``.
            ``conllu_str`` is ``""`` when no CoNLL-U graph was predicted.
        """
        pred_sem_graph, pred_syn_graph, conllu_graph = DecompGraphWithSyntax.from_prediction(
            outputs, self._model.syntactic_method)

        if conllu_graph is None:
            return pred_sem_graph, pred_syn_graph, ""

        if self._model.syntactic_method in ('concat-before', 'concat-after'):
            text = " ".join([row["form"] for row in conllu_graph])
            outputs['syn_nodes'] = text.split(" ")
        else:
            text = " ".join(outputs['syn_nodes'])

        # Only one sentence is dumped per call, so the sentence ids are a
        # fixed placeholder value.
        sent_id = 1
        colnames = [
            "ID", "form", "lemma", "upos", "xpos", "feats", "head",
            "deprel", "deps", "misc"
        ]

        lines = [
            f"# sent_id = train-s{sent_id}",
            f"# text = {text}",
            f"# org_sent_id = {sent_id}",
        ]
        for row in conllu_graph:
            lines.append("\t".join([row[cn] for cn in colnames]))

        # Cases where we had to trim: syntactic nodes beyond the CoNLL-U graph
        # are still emitted, as placeholder rows whose head is token 1 with
        # deprel "amod", so every node appears in the output.
        n_rows = len(conllu_graph)
        for token_id, node in enumerate(outputs['syn_nodes'][n_rows:],
                                        start=n_rows + 1):
            dummy_row = {
                "ID": str(token_id),
                "form": node,
                "lemma": "-",
                "upos": "-",
                "xpos": "-",
                "feats": "-",
                "head": "1",
                "deprel": "amod",
                "deps": "-",
                "misc": "-"
            }
            lines.append("\t".join([dummy_row[cn] for cn in colnames]))

        # Newline after the last row, plus the blank line that terminates the
        # sentence block.
        conllu_str = "\n".join(lines) + "\n\n"
        return pred_sem_graph, pred_syn_graph, conllu_str
Пример #2
0
def test_get_list_data_concat_before(load_dev_graphs):
    """Check ``get_list_data`` output for the ``concat-before`` syntactic method.

    Uses the 'basic' graph from the ``load_dev_graphs`` fixture. It compares
    the target tokens, head indices and head tags with hard-coded values.
    """
    # test concat-before
    d_graph = DecompGraphWithSyntax(load_dev_graphs['basic'],
                                    syntactic_method="concat-before")
    list_data = d_graph.get_list_data(bos="@start@",
                                      eos="@end@",
                                      max_tgt_length=100)
    # head_indices / head_tags have 15 entries versus 17 tgt_tokens.
    # Presumably they carry no entries for '@start@' and '@end@'.
    expected = {
        "tgt_tokens": [
            '@start@', 'comes', 'AP', 'story', ':', 'From', 'the', 'this',
            '@syntax-sep@', '@@ROOT@@', 'comes', 'AP', 'the', 'story', 'this',
            'From', '@end@'
        ],
        "head_indices": [0, 1, 1, 1, 2, 2, 3, -1, 0, 9, 10, 11, 10, 13, 10],
        "head_tags": [
            'root', 'nmod', 'nsubj', 'punct', 'case', 'det', 'det', 'SEP',
            'dependency', 'dependency', 'dependency', 'EMPTY', 'dependency',
            'EMPTY', 'EMPTY'
        ]
    }

    assert_dict(list_data, expected)
    def _update_validation_s_score(self,
                                   pred_instances: List[Dict[str,
                                                             numpy.ndarray]],
                                   true_instances):
        """
        Refresh attachment scores and, for non-syntax-only models, the S score.

        Attachment scores are always updated via
        ``self._update_attachment_scores``. S is not computed when
        ``self.model`` is a ``DecompSyntaxOnlyParser``,
        ``DecompTransformerSyntaxOnlyParser`` or ``UDParser``. For other
        models, the predicted semantic graphs are scored against the gold
        graphs with ``compute_s_metric``. The resulting precision, recall and
        F1 are multiplied by 100 and stored on ``self.model`` as
        ``val_s_precision``, ``val_s_recall`` and ``val_s_f1``.

        :param pred_instances: one prediction per instance, each accepted by
            ``DecompGraphWithSyntax.from_prediction``.
        :param true_instances: batches of gold instances. Each batch must
            contain exactly one element, whose ``'graph'`` and
            ``'src_tokens_str'`` entries are iterated.
        :raises ValueError: if a batch in ``true_instances`` does not contain
            exactly one element.
        """
        # Compute attachment scores here without having to override another
        # function.
        self._update_attachment_scores(pred_instances, true_instances)

        syntax_only_models = (DecompSyntaxOnlyParser,
                              DecompTransformerSyntaxOnlyParser, UDParser)
        if isinstance(self.model, syntax_only_models):
            # S is not computed for these model classes.
            return

        logger.info("Computing S")

        # An explicit check, unlike ``assert``, survives ``python -O``.
        # Only batch[0] is read below, so larger batches would be silently
        # truncated.
        for batch in true_instances:
            if len(batch) != 1:
                raise ValueError(
                    f"Expected validation batches of size 1, got {len(batch)}")

        true_graphs = [
            true_inst for batch in true_instances
            for true_inst in batch[0]['graph']
        ]
        true_sents = [
            true_inst for batch in true_instances
            for true_inst in batch[0]['src_tokens_str']
        ]

        pred_graphs = [
            DecompGraphWithSyntax.from_prediction(pred_inst,
                                                  self.syntactic_method)
            for pred_inst in pred_instances
        ]

        # from_prediction returns a 3-tuple whose first item is the semantic
        # graph, and only that item is scored. Unlike ``zip(*pred_graphs)``,
        # this comprehension also copes with an empty prediction list.
        pred_sem_graphs = tuple(sem_graph for sem_graph, _, _ in pred_graphs)

        ret = compute_s_metric(true_graphs, pred_sem_graphs, true_sents,
                               self.semantics_only, self.drop_syntax,
                               self.include_attribute_scores)

        self.model.val_s_precision = float(ret[0]) * 100
        self.model.val_s_recall = float(ret[1]) * 100
        self.model.val_s_f1 = float(ret[2]) * 100
Пример #4
0
def parse_api_sentence(input_line, args, predictor):
    """
    Run ``predictor`` on a single input and convert the result with ``arbor_to_uds``.

    A ``_ReturningPredictManager`` is configured for exactly one line: batch
    size 1, beam size 2, no oracle, and no file or console output. Its
    dataset reader is switched to API mode (``api_time = True``). The first
    prediction from ``run()`` is converted under the graph name
    ``"test-graph"``.

    :param input_line: passed as the manager's ``input_file``, and also as
        the last argument to ``arbor_to_uds``.
    :param args: currently unused.
    :param predictor: the predictor to run. For a
        ``DecompSyntaxParsingPredictor``, the prediction is unpacked into
        three items, and the first two go to
        ``DecompGraphWithSyntax.arbor_to_uds``. Otherwise the prediction is
        passed whole to ``DecompGraph.arbor_to_uds``.
    """
    manager = _ReturningPredictManager(
        predictor=predictor,
        input_file=input_line,
        output_file=None,
        json_output_file=None,
        batch_size=1,
        line_limit=1,
        beam_size=2,
        print_to_console=False,
        has_dataset_reader=True,
        oracle=False,
    )
    manager._dataset_reader.api_time = True

    # The predictions live at index 1 of run()'s result. Only one line was
    # requested, so the first entry is the one we want.
    prediction = manager.run()[1][0]
    graph_name = "test-graph"

    if not isinstance(predictor, DecompSyntaxParsingPredictor):
        return DecompGraph.arbor_to_uds(prediction, graph_name, input_line)

    sem_graph, syn_graph, _unused = prediction
    return DecompGraphWithSyntax.arbor_to_uds(sem_graph, syn_graph,
                                              graph_name, input_line)
Пример #5
0
    def text_to_instance(self, graph, do_print=False) -> Instance:
        """
        Do the bulk of the work of converting a graph into an ``Instance`` of fields.

        The graph is wrapped in ``DecompGraphWithSyntax``, configured from
        this reader's ``drop_syntax``, ``order``, ``syntactic_method`` and
        ``full_ud_parse``. It is then linearized with ``get_list_data``. The
        resulting lists are packed into text, sequence-label, array,
        adjacency, continuous-label and metadata fields.

        :param graph: the input graph handed to ``DecompGraphWithSyntax``.
        :param do_print: if true, call ``self.spot_check(graph, list_data)``
            as a debugging aid.
        :returns: the ``Instance``, or ``None`` when ``get_list_data``
            returns ``None``. Callers must handle ``None`` despite the
            annotation.

        Side effects: increments ``self._number_bert_ids`` and
        ``self._number_bert_oov_ids`` when subtoken ids are present. Also
        increments ``self.over_len`` when the target sequence has more than
        60 tokens.
        """
        # pylint: disable=arguments-differ

        fields: Dict[str, Field] = {}

        # Targets are capped at 90 tokens unless the reader is in eval mode.
        max_tgt_length = None if self.eval else 90
        d = DecompGraphWithSyntax(graph,
                                  drop_syntax=self.drop_syntax,
                                  order=self.order,
                                  syntactic_method=self.syntactic_method,
                                  full_ud_parse=self.full_ud_parse)

        list_data = d.get_list_data(bos=START_SYMBOL,
                                    eos=END_SYMBOL,
                                    bert_tokenizer=self._tokenizer,
                                    max_tgt_length=max_tgt_length,
                                    semantics_only=self.semantics_only)

        if list_data is None:
            return None

        if do_print:
            self.spot_check(graph, list_data)

        # These four fields are used for seq2seq model and target side self copy
        fields["source_tokens"] = TextField(
            tokens=[Token(x) for x in list_data["src_tokens"]],
            token_indexers=self._source_token_indexers)

        if list_data['src_token_ids'] is not None:
            fields['source_subtoken_ids'] = ArrayField(
                list_data['src_token_ids'])
            self._number_bert_ids += len(list_data['src_token_ids'])
            # NOTE(review): 100 is presumably the BERT [UNK] id - confirm
            # against the tokenizer vocabulary.
            self._number_bert_oov_ids += sum(
                1 for bert_id in list_data['src_token_ids'] if bert_id == 100)

        if list_data['src_token_subword_index'] is not None:
            fields['source_token_recovery_matrix'] = ArrayField(
                list_data['src_token_subword_index'])

        # Target-side input.
        # (exclude the last one <EOS>.)
        fields["target_tokens"] = TextField(
            tokens=[Token(x) for x in list_data["tgt_tokens"][:-1]],
            token_indexers=self._target_token_indexers)

        # Count over-long targets (more than 60 tokens).
        if len(list_data['tgt_tokens']) > 60:
            self.over_len += 1

        fields["source_pos_tags"] = SequenceLabelField(
            labels=list_data["src_pos_tags"],
            sequence_field=fields["source_tokens"],
            label_namespace="pos_tags")

        if list_data["tgt_pos_tags"] is not None:
            fields["target_pos_tags"] = SequenceLabelField(
                labels=list_data["tgt_pos_tags"][:-1],
                sequence_field=fields["target_tokens"],
                label_namespace="pos_tags")

        fields["target_node_indices"] = SequenceLabelField(
            labels=list_data["tgt_indices"][:-1],
            sequence_field=fields["target_tokens"],
            label_namespace="node_indices",
        )

        # Target-side output.
        # Include <BOS> here because we want it in the generation vocabulary such that
        # at the inference starting stage, <BOS> can be correctly initialized.
        fields["generation_outputs"] = TextField(
            tokens=[Token(x) for x in list_data["tgt_tokens_to_generate"]],
            token_indexers=self._generation_token_indexers)

        fields["target_copy_indices"] = SequenceLabelField(
            labels=list_data["tgt_copy_indices"],
            sequence_field=fields["generation_outputs"],
            label_namespace="target_copy_indices",
        )

        fields[
            "target_attention_map"] = AdjacencyField(  # TODO: replace it with ArrayField.
                indices=list_data["tgt_copy_map"],
                sequence_field=fields["generation_outputs"],
                padding_value=0)

        # These two fields for source copy

        fields["source_copy_indices"] = SequenceLabelField(
            labels=list_data["src_copy_indices"],
            sequence_field=fields["generation_outputs"],
            label_namespace="source_copy_indices",
        )

        # The adjacency is defined over the copy vocab's special tokens
        # followed by the source tokens.
        fields[
            "source_attention_map"] = AdjacencyField(  # TODO: replace it with ArrayField.
                indices=list_data["src_copy_map"],
                sequence_field=TextField([
                    Token(x) for x in
                    list_data["src_copy_vocab"].get_special_tok_list() +
                    list_data["src_tokens"]
                ], None),
                padding_value=0)

        # These two fields are used in biaffine parser
        fields["edge_types"] = TextField(
            tokens=[Token(x) for x in list_data["head_tags"]],
            token_indexers=self._edge_type_indexers)

        fields["edge_heads"] = SequenceLabelField(
            labels=list_data["head_indices"],
            sequence_field=fields["edge_types"],
            label_namespace="edge_heads")

        if list_data.get('node_mask', None) is not None:
            # Valid nodes are 1; pads are 0.
            fields['valid_node_mask'] = ArrayField(list_data['node_mask'])

        if list_data.get('edge_mask', None) is not None:
            # A matrix of shape [num_nodes, num_nodes] where entry (i, j) is 1
            # if and only if (1) j < i and (2) j is not an antecedent of i.
            # TODO: try to remove the second constrain.
            fields['edge_head_mask'] = ArrayField(list_data['edge_mask'])

        # node attributes
        fields["target_attributes"] = ContinuousLabelField(
            labels=list_data["tgt_attributes"][:-1],
            sequence_field=fields["target_tokens"],
            ontology=NODE_ONTOLOGY)

        # edge attributes
        fields["edge_attributes"] = ContinuousLabelField(
            labels=list_data["edge_attributes"][:-1],
            sequence_field=fields["target_tokens"],
            ontology=EDGE_ONTOLOGY)

        # this field is actually needed for scoring later
        fields["graph"] = MetadataField(list_data['arbor_graph'])

        fields["true_conllu_dict"] = MetadataField(
            list_data['true_conllu_dict'])

        # Metadata fields, good for debugging
        fields["src_tokens_str"] = MetadataField(list_data["src_tokens"])

        fields["edge_types_str"] = MetadataField(list_data['head_tags'])

        fields["tgt_tokens_str"] = MetadataField(
            list_data.get("tgt_tokens", []))

        fields["src_copy_vocab"] = MetadataField(list_data["src_copy_vocab"])

        fields["tag_lut"] = MetadataField(dict(pos=list_data["pos_tag_lut"]))

        fields["source_copy_invalid_ids"] = MetadataField(
            list_data['src_copy_invalid_ids'])

        fields["node_name_list"] = MetadataField(list_data['node_name_list'])
        fields["target_dynamic_vocab"] = MetadataField(dict())

        fields["instance_meta"] = MetadataField(
            dict(
                pos_tag_lut=list_data["pos_tag_lut"],
                source_dynamic_vocab=list_data["src_copy_vocab"],
                target_token_indexers=self._target_token_indexers,
            ))

        # Extra syntactic-parse fields are only built for the encoder-side
        # syntactic method.
        if self.syntactic_method == "encoder-side":

            fields["syn_edge_types"] = TextField(
                tokens=[Token(x) for x in list_data["syn_head_tags"]],
                token_indexers=self._syntax_edge_type_indexers,
            )

            fields["syn_edge_heads"] = SequenceLabelField(
                labels=list_data["syn_head_indices"],
                sequence_field=fields["syn_edge_types"],
                label_namespace="syn_edge_heads")

            fields['syn_edge_head_mask'] = ArrayField(
                list_data['syn_edge_mask'])
            fields['syn_valid_node_mask'] = ArrayField(
                list_data['syn_node_mask'])

            fields["syn_tokens_str"] = MetadataField(list_data["syn_tokens"])

            fields["syn_node_name_list"] = MetadataField(
                list_data["syn_node_name_list"])

        return Instance(fields)
Пример #6
0
def get_decomp_graph(pp_graph):
    """Return ``pp_graph`` wrapped in a ``DecompGraphWithSyntax`` built with no extra arguments."""
    return DecompGraphWithSyntax(pp_graph)