Exemple #1
0
    def _create_possible_rules(node, sentence):
        processor = LabelProcessor()

        anchors = node["anchors"]
        if len(anchors) == 0:
            return [{
                "rule":
                processor.make_absolute_label_rule(node["label"].lower()),
                "anchor":
                None
            }]

        rules = processor.gen_all_label_rules(
            [sentence["input"][anchor] for anchor in anchors],
            [sentence["lemmas"][anchor] for anchor in anchors],
            node["label"],
            separators=['', '_', ' '],
            rule_classes=[
                "absolute", "relative_forms", "relative_lemmas",
                "numerical_all"
            ],
            num_separator='_',
            concat=True,
            allow_copy=False,
        )
        return [{"rule": rule, "anchor": anchors} for rule in set(rules)]
Exemple #2
0
    def _create_possible_rules(node, sentence):
        processor = LabelProcessor()
        rules = []

        for anchor in node["anchors"]:
            _rules = processor.gen_all_label_rules(
                sentence["input"][anchor:],
                sentence["lemmas"][anchor:],
                node["label"],
                separators=['', '-'],
                rule_classes=[
                    "relative_forms", "relative_lemmas", "numerical_divide",
                    "concatenate"
                ],
                concat=False,
                allow_copy=False,
            )
            for rule in set(_rules):
                rules.append({"rule": rule, "anchor": [anchor]})

        rules.append({
            "rule": processor.make_absolute_label_rule(node["label"]),
            "anchor": None
        })

        return rules
Exemple #3
0
    def __init__(self, args):
        self.sos, self.eos, self.pad, self.unk = "<sos>", "<eos>", "<pad>", "<unk>"
        self.processor = LabelProcessor()

        self.bert_input_field = BertField()
        self.scatter_field = BasicField()
        self.every_word_input_field = Field(lower=True,
                                            init_token=self.sos,
                                            eos_token=self.eos,
                                            batch_first=True,
                                            include_lengths=True)
        self.every_lemma_field = Field(lower=True,
                                       init_token=self.sos,
                                       eos_token=self.eos,
                                       batch_first=True)

        char_form_nesting = torchtext.data.Field(tokenize=char_tokenize,
                                                 init_token=self.sos,
                                                 eos_token=self.eos,
                                                 batch_first=True)
        self.char_form_field = NestedField(char_form_nesting,
                                           include_lengths=True)

        char_lemma_nesting = torchtext.data.Field(tokenize=char_tokenize,
                                                  init_token=self.sos,
                                                  eos_token=self.eos,
                                                  batch_first=True)
        self.char_lemma_field = NestedField(char_lemma_nesting,
                                            include_lengths=True)

        self.relative_label_field = RelativeLabelField(
            args.label_smoothing,
            preprocessing=lambda nodes: [n["possible rules"] for n in nodes],
        )
        self.label_field = Field(
            preprocessing=lambda nodes: [n["label"].lower() for n in nodes],
            include_lengths=True,
            batch_first=True,
        )
        self.property_field = PropertyField(
            preprocessing=lambda nodes: [n["properties"] for n in nodes])

        self.id_field = Field(batch_first=True)
        self.edge_presence_field = EdgeField()
        self.edge_attribute_field = EdgeField()
        self.edge_label_field = EdgeLabelField()
        self.edge_permutation_field = EdgePermutationField()
        self.anchor_field = AnchorField()
        self.token_interval_field = BasicField()
        self.top_field = BasicField()
Exemple #4
0
def change_unnecessary_relative_rules(data):
    processor = LabelProcessor()

    label_sets = {}
    for n, _ in node_generator(data):
        for rule in n["possible rules"]:
            rule = rule["rule"]
            if rule not in label_sets:
                label_sets[rule] = set()
            label_sets[rule].add(n["label"].lower())

    for n, _ in node_generator(data):
        for i, rule in enumerate(n["possible rules"]):
            rule = rule["rule"]
            if len(label_sets[rule]) == 1:
                absolute_label = processor.make_absolute_label_rule(n['label'])
                n["possible rules"][i] = {
                    "rule": absolute_label,
                    "anchor": None
                }
Exemple #5
0
class Dataset:
    def __init__(self, args):
        self.sos, self.eos, self.pad, self.unk = "<sos>", "<eos>", "<pad>", "<unk>"
        self.processor = LabelProcessor()

        self.bert_input_field = BertField()
        self.scatter_field = BasicField()
        self.every_word_input_field = Field(lower=True,
                                            init_token=self.sos,
                                            eos_token=self.eos,
                                            batch_first=True,
                                            include_lengths=True)
        self.every_lemma_field = Field(lower=True,
                                       init_token=self.sos,
                                       eos_token=self.eos,
                                       batch_first=True)

        char_form_nesting = torchtext.data.Field(tokenize=char_tokenize,
                                                 init_token=self.sos,
                                                 eos_token=self.eos,
                                                 batch_first=True)
        self.char_form_field = NestedField(char_form_nesting,
                                           include_lengths=True)

        char_lemma_nesting = torchtext.data.Field(tokenize=char_tokenize,
                                                  init_token=self.sos,
                                                  eos_token=self.eos,
                                                  batch_first=True)
        self.char_lemma_field = NestedField(char_lemma_nesting,
                                            include_lengths=True)

        self.relative_label_field = RelativeLabelField(
            args.label_smoothing,
            preprocessing=lambda nodes: [n["possible rules"] for n in nodes],
        )
        self.label_field = Field(
            preprocessing=lambda nodes: [n["label"].lower() for n in nodes],
            include_lengths=True,
            batch_first=True,
        )
        self.property_field = PropertyField(
            preprocessing=lambda nodes: [n["properties"] for n in nodes])

        self.id_field = Field(batch_first=True)
        self.edge_presence_field = EdgeField()
        self.edge_attribute_field = EdgeField()
        self.edge_label_field = EdgeLabelField()
        self.edge_permutation_field = EdgePermutationField()
        self.anchor_field = AnchorField()
        self.token_interval_field = BasicField()
        self.top_field = BasicField()

    def load_state_dict(self, args, d):
        self.property_keys = d["property keys"]
        self.property_field.vocabs = pickle.loads(d["property vocabs"])
        for key, value in d["vocabs"].items():
            getattr(self, key).vocab = pickle.loads(value)

    def state_dict(self):
        return {
            "property vocabs": pickle.dumps(self.property_field.vocabs),
            "property keys": self.property_keys,
            "vocabs": {
                key: pickle.dumps(value.vocab)
                for key, value in self.__dict__.items()
                if hasattr(value, "vocab")
            }
        }

    def load_sentences(self, sentences, args, language: str):
        dataset = RequestParser(
            sentences,
            args,
            language,
            fields={
                "input": [("every_input", self.every_word_input_field),
                          ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "lemmas": [("every_lemma", self.every_lemma_field),
                           ("char_lemma_input", self.char_lemma_field)],
                "to scatter": ("input_scatter", self.scatter_field),
                "token anchors":
                ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
        )

        self.every_word_input_field.build_vocab(
            dataset,
            min_freq=1,
            specials=[self.pad, self.unk, self.sos, self.eos])
        self.every_lemma_field.build_vocab(
            dataset,
            min_freq=1,
            specials=[self.pad, self.unk, self.sos, self.eos])
        self.id_field.build_vocab(dataset, min_freq=1, specials=[])

        return dataset

    def load_dataset(self, args, gpu, n_gpus, framework: str, language: str):
        dataset = {
            ("dcr", "eng"): DcrParser,
            ("amr", "eng"): AMRParser,
            ("amr", "zho"): EDSParser,
            ("drg", "eng"): DRGParser,
            ("drg", "deu"): DRGParser,
            ("eds", "eng"): EDSParser,
            ("ptg", "eng"): PTGParser,
            ("ptg", "ces"): PTGParser,
            ("ucca", "eng"): UCCAParser,
            ("ucca", "deu"): UCCAParser,
        }[(framework, language)]

        self.train = dataset(
            args,
            framework,
            language,
            "training",
            fields={
                "input": [("every_input", self.every_word_input_field),
                          ("char_form_input", self.char_form_field)],
                "lemmas": ("char_lemma_input", self.char_lemma_field),
                "bert input": ("input", self.bert_input_field),
                "to scatter": ("input_scatter", self.scatter_field),
                "edge permutations":
                ("edge_permutations", self.edge_permutation_field),
                "nodes": [
                    ("labels", self.label_field),
                    ("relative_labels", self.relative_label_field),
                    ("properties", self.property_field),
                ],
                "top": ("top", self.top_field),
                "edge presence": ("edge_presence", self.edge_presence_field),
                "edge labels": ("edge_labels", self.edge_label_field),
                "edge attributes":
                ("edge_attributes", self.edge_attribute_field),
                "anchor edges": ("anchor", self.anchor_field),
            },
            filter_pred=lambda example: len(example.input) <=
            1080,  #changed - from 80
        )

        self.val = dataset(
            args,
            framework,
            language,
            "validation",
            fields={
                "input": [("every_input", self.every_word_input_field),
                          ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "lemmas": [("every_lemma", self.every_lemma_field),
                           ("char_lemma_input", self.char_lemma_field)],
                "to scatter": ("input_scatter", self.scatter_field),
                "edge permutations":
                ("edge_permutations", self.edge_permutation_field),
                "nodes": [
                    ("labels", self.label_field),
                    ("relative_labels", self.relative_label_field),
                    ("properties", self.property_field),
                ],
                "top": ("top", self.top_field),
                "edge presence": ("edge_presence", self.edge_presence_field),
                "edge labels": ("edge_labels", self.edge_label_field),
                "edge attributes":
                ("edge_attributes", self.edge_attribute_field),
                "anchor edges": ("anchor", self.anchor_field),
                "token anchors":
                ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
            precomputed_dataset=self.train.data,
        )

        self.test = EvaluationParser(
            args,
            framework,
            language,
            fields={
                "input": [("every_input", self.every_word_input_field),
                          ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "lemmas": [("every_lemma", self.every_lemma_field),
                           ("char_lemma_input", self.char_lemma_field)],
                "to scatter": ("input_scatter", self.scatter_field),
                "token anchors":
                ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
        )

        del self.train.data, self.val.data, self.test.data
        for f in list(self.train.fields.values()) + list(
                self.val.fields.values()) + list(self.test.fields.values()):
            if hasattr(f, "preprocessing"):
                del f.preprocessing

        self.train_size = len(self.train)
        self.val_size = len(self.val)
        self.test_size = len(self.test)

        print(f"\n{self.train_size} sentences in the train split")
        print(f"{self.val_size} sentences in the validation split")
        print(f"{self.test_size} sentences in the test split")

        self.node_count = self.train.node_counter
        self.token_count = self.train.input_count
        self.edge_count = self.train.edge_counter
        self.no_edge_count = self.train.no_edge_counter
        self.anchor_freq = self.train.anchor_freq
        print(f"{self.node_count} nodes in the train split")

        self.every_word_input_field.build_vocab(
            self.val,
            self.test,
            min_freq=1,
            specials=[self.pad, self.unk, self.sos, self.eos])
        self.char_form_field.build_vocab(
            self.train,
            min_freq=1,
            specials=[self.pad, self.unk, self.sos, self.eos])
        self.char_lemma_field.build_vocab(
            self.train,
            min_freq=1,
            specials=[self.pad, self.unk, self.sos, self.eos])
        self.every_lemma_field.build_vocab(
            self.val,
            self.test,
            min_freq=1,
            specials=[self.pad, self.unk, self.sos, self.eos])
        self.label_field.build_vocab(self.train, min_freq=1, specials=[])
        self.property_field.build_vocab(self.train)
        self.relative_label_field.build_vocab(self.train)
        self.id_field.build_vocab(self.val, self.test, min_freq=1, specials=[])
        self.edge_label_field.build_vocab(self.train)
        self.edge_attribute_field.build_vocab(self.train)

        self.create_label_freqs(args)
        self.create_edge_freqs(args)
        self.create_property_freqs(args)
        self.create_top_freqs(args)

        self.property_keys = self.property_field.keys
        print("properties: ", self.property_field.keys)

        print(f"Edge frequency: {self.edge_presence_freq*100:.2f} %")
        print(
            f"{len(self.relative_label_field.vocab)} words in the relative label vocabulary"
        )
        print(
            f"{len(self.edge_label_field.vocab)} words in the edge label vocabulary"
        )
        print(
            f"{len(self.char_form_field.vocab)} characters in the vocabulary")

        Random(42).shuffle(self.train.examples)
        self.train.examples = self.train.examples[:len(self.train.examples) //
                                                  n_gpus * n_gpus]
        self.train.examples = self.train.examples[gpu *
                                                  len(self.train.examples) //
                                                  n_gpus:(gpu + 1) *
                                                  len(self.train.examples) //
                                                  n_gpus]

    def relative_output_tensor_to_str(self,
                                      relative_label,
                                      anchors,
                                      tokens,
                                      lemmas,
                                      concat_rules: bool,
                                      num_lemmas=False,
                                      ignore_nonalnum=False):
        relative_label_str = self.relative_label_field.vocab.itos[
            relative_label]

        input_strings = [
            self.every_word_input_field.vocab.itos[i.item()]
            for i in tokens[anchors]
        ]
        lemma_strings = [
            self.every_lemma_field.vocab.itos[i.item()]
            for i in lemmas[anchors]
        ]

        absolute_label = self.processor.apply_label_rule(
            input_strings, lemma_strings, relative_label_str, concat_rules,
            num_lemmas, ignore_nonalnum)

        if absolute_label is None:
            absolute_label = "<none>"
        return absolute_label

    def create_label_freqs(self, args):
        n_rules = len(self.relative_label_field.vocab)
        blank_count = (args.query_length * self.token_count -
                       self.node_count) * args.blank_weight
        blank_p = blank_count * (
            1.0 - args.label_smoothing
        ) + self.node_count * args.label_smoothing / n_rules
        non_blank_p = blank_count * args.label_smoothing / n_rules
        label_counts = [blank_p] + [
            self.train.rule_counter[self.relative_label_field.vocab.itos[i]] +
            non_blank_p for i in range(n_rules)
        ]
        label_counts = torch.FloatTensor(label_counts)
        self.label_freqs = label_counts / (self.node_count + blank_count)

    def create_edge_freqs(self, args):
        edge_counter = [
            self.edge_label_field.vocab.freqs[
                self.edge_label_field.vocab.itos[i]]
            for i in range(len(self.edge_label_field.vocab))
        ]
        edge_counter = torch.FloatTensor(edge_counter)
        self.edge_label_freqs = edge_counter / self.edge_count
        self.edge_presence_freq = self.edge_count / (self.edge_count +
                                                     self.no_edge_count)
        edge_attribute_counter = [
            self.edge_attribute_field.vocab.freqs[
                self.edge_attribute_field.vocab.itos[i]]
            for i in range(len(self.edge_attribute_field.vocab))
        ]
        edge_attribute_counter = torch.FloatTensor(edge_attribute_counter)
        self.edge_attribute_freqs = edge_attribute_counter / self.edge_count

    def create_property_freqs(self, args):
        property_counter = {
            key: [vocab.freqs[vocab.itos[i]] for i in range(len(vocab))]
            for key, vocab in self.property_field.vocabs.items()
        }
        self.property_freqs = {
            key: torch.FloatTensor(c) / self.node_count
            for key, c in property_counter.items()
        }

    def create_top_freqs(self, args):
        self.top_freq = self.train_size / self.node_count
Exemple #6
0
#!/usr/bin/env python3
# conding=utf-8
#
# Copyright 2020 Institute of Formal and Applied Linguistics, Faculty of
# Mathematics and Physics, Charles University, Czech Republic.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from utility.label_processor import LabelProcessor

processor = LabelProcessor()


def check(forms, lemma, label, rule_classes, concat: bool, num_separator=''):
    rules = list(
        processor.gen_all_label_rules(forms,
                                      forms,
                                      label,
                                      rule_classes, ['', '+', '-'],
                                      concat,
                                      num_separator=num_separator))

    for rule in rules:
        assert processor.apply_label_rule(forms, forms, rule, concat) == label

    return rules


def test_all_1():