コード例 #1
0
    def get_utterance_semantics_pairs(self,
                                      random_source,
                                      rule_sets,
                                      branch_cap=None):
        """Generate printed utterance/semantics pairs for the chosen rule sets.

        Args:
            random_source: Random generator forwarded to the pair generator
                as ``random_generator``.
            rule_sets: Iterable of indices selecting entries of
                ``self.rules``; each index is shifted down by one before the
                lookup.
            branch_cap: Forwarded unchanged to the pair generator.

        Returns:
            A dict mapping ``tree_printer(utterance)`` to
            ``tree_printer(parse)``. When two rule sets yield the same
            printed utterance, the later one wins.
        """
        all_pairs = {}
        selected_rule_sets = [self.rules[index - 1] for index in rule_sets]

        # Each entry is a 4-tuple; only the grounded rules and the semantics
        # are needed here.
        for _, _, rules_ground, semantics in selected_rule_sets:
            # The semantic form version decides which generator produces the
            # pairs; both are called with the same arguments.
            if self.semantic_form_version == "slot":
                generate_pairs = generate_sentence_slot_pairs
            else:
                generate_pairs = generate_sentence_parse_pairs

            pairs = generate_pairs(
                ROOT_SYMBOL,
                rules_ground,
                semantics,
                yield_requires_semantics=True,
                branch_cap=branch_cap,
                random_generator=random_source)

            for utterance, parse in pairs:
                all_pairs[tree_printer(utterance)] = tree_printer(parse)
        return all_pairs
コード例 #2
0
def main():
    """Generate grounded examples and write them to a CSV, one HIT per row.

    Reads ``seed``, ``groundings_per_parse`` and ``rehprasings_per_hit`` from
    names defined outside this function. Each row holds
    ``rehprasings_per_hit`` (command, parse, parse_ground) triples; a short
    final row is topped up with triples sampled from earlier rows. The file
    is read back afterwards to check that it can be loaded.
    """
    random_source = random.Random(seed)
    grammar_dir = os.path.abspath(
        os.path.dirname(__file__) + "/../../resources/generator2018")
    out_file_path = os.path.abspath(
        os.path.dirname(__file__) +
        "/../../data/rephrasings_data_{}_{}.csv".format(
            seed, groundings_per_parse))
    cmd_gen = Generator(grammar_format_version=2018)
    generator = load_all_2018_by_cat(cmd_gen, grammar_dir)

    all_examples = []
    for _ in range(groundings_per_parse):
        grounded_examples = get_grounding_per_each_parse(
            generator, random_source)
        random_source.shuffle(grounded_examples)
        all_examples += grounded_examples

    # The csv module expects files opened with newline='' so that it controls
    # line endings itself (otherwise extra blank rows can appear on Windows).
    with open(out_file_path, 'w', newline='') as csvfile:
        output = csv.writer(csvfile, quoting=csv.QUOTE_MINIMAL)
        # Header: command1, parse1, parse_ground1, command2, ...
        command_columns = [
            column
            for x in range(1, rehprasings_per_hit + 1)
            for column in ("command" + str(x), "parse" + str(x),
                           "parse_ground" + str(x))
        ]
        output.writerow(command_columns)

        chunks = list(chunker(all_examples, rehprasings_per_hit))
        print("Writing {} HITS".format(len(chunks)))
        for i, chunk in enumerate(chunks):
            if len(chunk) < rehprasings_per_hit:
                needed = rehprasings_per_hit - len(chunk)
                # Sample from previous hits to fill out this last one.
                # random.sample raises ValueError when fewer than `needed`
                # earlier triples exist (e.g. when this is the only chunk).
                earlier_triples = [
                    triple for earlier in chunks[:i] for triple in earlier
                ]
                chunk += random_source.sample(earlier_triples, k=needed)
            line = []
            for utterance, parse_anon, parse_ground in chunk:
                line += [
                    tree_printer(utterance),
                    tree_printer(parse_anon),
                    tree_printer(parse_ground)
                ]
            output.writerow(line)

    # Let's verify that we can load the output back in...
    with open(out_file_path, 'r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for _ in reader:
            pass
コード例 #3
0
def pairs_without_placeholders(rules, semantics, only_in_grammar=False):
    """Expand all semantics and keep the pairs that contain no placeholders.

    Args:
        rules: Grammar rules passed to ``expand_all_semantics`` and, when
            filtering, to ``generate_sentences``.
        semantics: Semantics passed to ``expand_all_semantics``.
        only_in_grammar: When true, also drop pairs whose command is not
            among the sentences generated from ``ROOT_SYMBOL``.

    Returns:
        A dict mapping each printed command to its printed parse.
    """
    pairs = expand_all_semantics(rules, semantics)
    out = {}
    # The set of grammar sentences is only consulted by the membership filter,
    # so it is not built unless that filter is requested.
    all_utterances_in_grammar = None
    if only_in_grammar:
        all_utterances_in_grammar = set(generate_sentences(ROOT_SYMBOL, rules))
    for command, parse in pairs:
        if has_placeholders(command) or has_placeholders(parse):
            # This case is almost certainly a bug with the annotations
            print(
                "Skipping pair for {} because it still has placeholders after expansion"
                .format(tree_printer(command)))
            continue
        # If it's important that we only get pairs that are in the grammar, check to make sure
        if only_in_grammar and command not in all_utterances_in_grammar:
            continue
        out[tree_printer(command)] = tree_printer(parse)
    return out
コード例 #4
0
def get_annotated_sentences(sentences_and_pairs):
    """Return the annotated sentences that also occur in the given sentences.

    Args:
        sentences_and_pairs: A 2-item sequence ``(sentences, pairs)``.
            ``sentences`` is an iterable of trees that are rendered with
            ``tree_printer``; ``pairs`` is a mapping whose keys are printed
            sentences.

    Returns:
        A new set holding the keys of ``pairs`` that equal the printed form
        of some item in ``sentences``.
    """
    sentences, pairs = sentences_and_pairs
    expanded_sentences = {tree_printer(x) for x in sentences}
    # Only keep annotations that cover sentences actually in the grammar
    return set(pairs.keys()).intersection(expanded_sentences)
コード例 #5
0
    def test_nearest_neighbor_parser(self):
        """The nearest-neighbor parser should match the grammar parser.

        Checked for one generated sentence and for that same sentence with
        its final character removed.
        """
        generator = Generator(grammar_format_version=2018)
        loaded = load_all_2019(generator, GRAMMAR_DIR)
        grammar_rules = loaded[0]

        generated = generate_sentences(ROOT_SYMBOL, grammar_rules)
        grammar_parser = GrammarBasedParser(grammar_rules)
        # De-duplicate the printed sentences before building the neighbors.
        printed = list({tree_printer(tree) for tree in generated})

        training_pairs = []
        for text in printed:
            training_pairs.append((text, grammar_parser(text)))
        knn_parser = KNearestNeighborParser(training_pairs)

        reference = printed[0]
        truncated = reference[:-1]
        expected_parse = grammar_parser(reference)

        self.assertEqual(knn_parser(reference), expected_parse)
        self.assertEqual(knn_parser(truncated), expected_parse)
コード例 #6
0
    def test_parse_all_of_2019(self):
        """Every distinct sentence generated from the rules should parse."""
        generator = Generator(grammar_format_version=2018)
        resources = os.path.dirname(__file__) + "/../resources/generator2019"
        grammar_dir = os.path.abspath(resources)

        # Only the first of the five loaded items is used by this test.
        rules, _, _, _, _ = load_all_2019(generator, grammar_dir)

        generated = generate_sentences(ROOT_SYMBOL, rules)
        parser = GrammarBasedParser(rules)
        printed = {tree_printer(tree) for tree in generated}

        # Count the sentences for which the parser returns a truthy result.
        succeeded = sum(1 for text in printed if parser(text))

        self.assertEqual(len(printed), succeeded)
コード例 #7
0
    def test_parse_all_2019_anonymized(self):
        """Grounded sentences should parse once wrapped by the anonymizer.

        Takes up to ``num_tested`` generated sentences, parses each with an
        ``AnonymizingParser`` built over the anonymized rules, and prints
        diagnostics for every sentence that fails to parse.
        """
        generator = Generator(grammar_format_version=2019)
        resources = os.path.dirname(__file__) + "/../resources/generator2019"
        grammar_dir = os.path.abspath(resources)
        loaded = load_all_2019(generator, grammar_dir)
        rules, rules_anon, rules_ground, semantics, entities = loaded

        sentences = generate_sentence_parse_pairs(
            ROOT_SYMBOL,
            rules_ground, {},
            yield_requires_semantics=False,
            random_generator=random.Random(1))
        base_parser = GrammarBasedParser(rules_anon)

        # Why the plain grammar parser is not enough on its own:
        #
        #   "Bring me the apple from the fridge to the kitchen"
        #     -> anonymized straight to clusters:
        #   "Bring me the {ob}  from the {loc} to the {loc}"
        #     -> grammar-based parser fails (grammar has numbers on locs)
        #
        #   "Bring me the apple from the fridge to the kitchen"
        #     -> anonymized with naive id numbering:
        #   "Bring me the {ob}  from the {loc 1} to the {loc 2}"
        #     -> grammar-based parser fails (wrong numbers, possibly)

        anonymizer = Anonymizer(*entities)
        parser = AnonymizingParser(base_parser, anonymizer)
        num_tested = 1000
        succeeded = 0
        for tree, _ in itertools.islice(sentences, num_tested):
            text = tree_printer(tree)
            if parser(text):
                succeeded += 1
                continue
            # Dump the failing sentence, its anonymized form, and the parse
            # of that anonymized form.
            print(text)
            print(anonymizer(text))
            print()
            print(parser(anonymizer(text)))

        self.assertEqual(succeeded, num_tested)
コード例 #8
0
def load_data(path, lambda_parser):
    """Load source/target pairs stored on alternating lines of a text file.

    A non-empty line is taken as the source sequence and the line after it as
    the target sequence. Empty lines in source position are skipped. The
    target is parsed with ``lambda_parser.parse`` and stored in printed form;
    targets that raise a ``LarkError`` are reported and skipped.

    Args:
        path: Path of the file to read.
        lambda_parser: Object whose ``parse`` method is applied to each
            target line.

    Returns:
        A dict mapping each source line to ``tree_printer`` of its parsed
        target. A repeated source line keeps its last successfully parsed
        target.

    Raises:
        RuntimeError: If a source line is the last line of the file.
    """
    pairs = {}
    with open(path) as f:
        line_generator = more_itertools.peekable(enumerate(f))
        # A peekable is falsy once the underlying iterator is exhausted.
        while line_generator:
            line_num, line = next(line_generator)
            line = line.strip("\n")
            if not line:
                continue

            next_pair = next(line_generator, None)
            if next_pair is None:
                raise RuntimeError(
                    "{}: source on line {} has no target line after it".format(
                        path, line_num + 1))
            _, target_sequence = next_pair
            source_sequence = line

            try:
                pairs[source_sequence] = tree_printer(lambda_parser.parse(target_sequence))
            except lark.exceptions.LarkError:
                print("Skipping malformed parse: {}".format(target_sequence))
    return pairs
コード例 #9
0
    def test_parse_choice(self):
        """Parse several choice expressions with the generator grammar parser.

        Asserts that a top-level choice parses the same with and without
        surrounding parentheses; the other parses are only printed.
        """
        parse = self.generator.generator_grammar_parser.parse

        printed_only = (
            "$test = ( oneword | two words)",
            "$test = ( front | back | main | rear ) $ndoor",
        )
        for rule_text in printed_only:
            print(parse(rule_text).pretty())

        bare = parse("$test = front | back")
        parenthesized = parse("$test = (front | back)")
        self.assertEqual(bare, parenthesized)

        mixed_length = parse("$test = aa | aa ba")
        print(mixed_length.pretty())

        nested = parse(
            "$phpeople    = everyone | all the (people | men | women | guests | elders | children)"
        )
        print(nested.pretty())
        print(tree_printer(nested))
コード例 #10
0
def main():
    """Write per-category sentences, pairs and annotation coverage statistics.

    Under the ``data`` directory next to this script's grandparent, writes
    ``<i>_sentences.txt`` and ``<i>_pairs.txt`` for i in 1..3, and
    ``annotations_meta.txt`` with coverage and length statistics. Finally
    prints every generated sentence that has no annotation.
    """
    out_root = os.path.abspath(os.path.dirname(__file__) + "/../../data")
    grammar_dir = os.path.abspath(
        os.path.dirname(__file__) + "/../../resources/generator2018")

    cmd_gen = Generator(grammar_format_version=2018)
    generator = load_all_2018_by_cat(cmd_gen, grammar_dir)

    cat_sentences = [
        set(generate_sentences(ROOT_SYMBOL, rules))
        for _, rules, _, _ in generator
    ]
    pairs = [
        pairs_without_placeholders(rules, semantics)
        for _, rules, _, semantics in generator
    ]
    # NOTE(review): the returned values are not used below. The call is kept
    # in case it modifies `pairs` in place - confirm and drop if it does not.
    determine_unique_cat_data(
        pairs, keep_new_utterance_repeat_parse_for_lower_cat=False)

    # Don't count a sentence as unique unless it hasn't happened in any
    # earlier category.
    unique = []
    seen_in_earlier_cats = set()
    for sentences in cat_sentences:
        unique.append(sentences.difference(seen_in_earlier_cats))
        seen_in_earlier_cats |= sentences
    # Sets should be disjoint
    assert (len(set().union(*unique)) == sum([len(cat) for cat in unique]))

    all_sentences = [tree_printer(x) for x in set().union(*cat_sentences)]

    annotated = [get_annotated_sentences(x) for x in zip(cat_sentences, pairs)]

    unique_annotated = [
        get_annotated_sentences((unique_sen, cat_pairs))
        for unique_sen, cat_pairs in zip(unique, pairs)
    ]
    # Distinct parses attached to each category's unique annotated sentences.
    unique_sentence_parses = [
        {cat_pairs[ann_sen] for ann_sen in cat_annotated}
        for cat_annotated, cat_pairs in zip(unique_annotated, pairs)
    ]

    combined_annotations = set().union(*annotated)
    combined_annotations.intersection_update(all_sentences)

    # NOTE(review): `sen` holds sentence trees while `annotated_sentences`
    # holds printed strings, so this difference may never remove anything -
    # confirm how the trees compare with strings.
    parseless = [
        sen.difference(annotated_sentences)
        for sen, annotated_sentences in zip(cat_sentences, annotated)
    ]

    out_paths = [
        join(out_root,
             str(i) + "_sentences.txt") for i in range(1, 4)
    ]

    for cat_out_path, sentences in zip(out_paths, cat_sentences):
        with open(cat_out_path, "w") as f:
            for sentence in sentences:
                assert not has_placeholders(sentence)
                f.write(tree_printer(sentence) + '\n')

    out_paths = [join(out_root, str(i) + "_pairs.txt") for i in range(1, 4)]

    for cat_out_path, cat_pairs in zip(out_paths, pairs):
        with open(cat_out_path, "w") as f:
            for sentence, parse in cat_pairs.items():
                f.write(sentence + '\n' + parse + '\n')

    # Tokens removed before measuring the "filtered" parse length: " e ",
    # double quotes and parentheses. Raw string avoids invalid escapes.
    stop_tokens = re.compile(r'( e |"|\)|\()')

    meta_out_path = join(out_root, "annotations_meta.txt")
    with open(meta_out_path, "w") as f:
        f.write("Coverage:\n")
        cat_parse_lengths = []
        cat_filtered_parse_lengths = []
        cat_sen_lengths = []
        for i, (annotated_sen, sen, unique_parses) in enumerate(
                zip(annotated, cat_sentences, unique_sentence_parses)):
            f.write("cat{0} {1}/{2} {3:.1f}%\n".format(
                i + 1, len(annotated_sen), len(sen),
                100.0 * len(annotated_sen) / len(sen)))
            f.write("\t unique parses: {}\n".format(len(unique_parses)))

            sen_lengths = [
                len(tree_printer(sentence).split()) for sentence in sen
            ]
            parse_lengths = []
            filtered_parse_lengths = []
            for parse in unique_parses:
                parse_lengths.append(len(parse.split()))
                stop_tokens_removed = stop_tokens.sub("", parse)
                filtered_parse_lengths.append(len(stop_tokens_removed.split()))

            cat_sen_lengths.append(sen_lengths)
            cat_parse_lengths.append(parse_lengths)
            cat_filtered_parse_lengths.append(filtered_parse_lengths)

            avg_sentence_length = np.mean(sen_lengths)
            avg_parse_length = np.mean(parse_lengths)
            avg_filtered_parse_length = np.mean(filtered_parse_lengths)
            f.write(
                "\t avg sentence length (tokens): {:.1f} avg parse length (tokens): {:.1f} avg filtered parse length (tokens): {:.1f}\n"
                .format(avg_sentence_length, avg_parse_length,
                        avg_filtered_parse_length))

        f.write("combined {0}/{1} {2:.1f}%\n".format(
            len(combined_annotations), len(all_sentences),
            100.0 * len(combined_annotations) / len(all_sentences)))
        f.write("combined unique parses: {}\n".format(
            len(set().union(*unique_sentence_parses))))

        all_sen_lengths = [length for cat in cat_sen_lengths for length in cat]
        all_parse_lengths = [
            length for cat in cat_parse_lengths for length in cat
        ]
        all_filtered_parse_lengths = [
            length for cat in cat_filtered_parse_lengths for length in cat
        ]
        f.write(
            "combined avg sentence length (tokens): {:.1f} avg parse length (tokens): {:.1f} avg filtered parse length (tokens): {:.1f}\n"
            .format(np.mean(all_sen_lengths), np.mean(all_parse_lengths),
                    np.mean(all_filtered_parse_lengths)))
    print("No parses for:")
    for cat in parseless:
        for sentence in sorted(map(tree_printer, cat)):
            print(sentence)
        print("-----------------")
コード例 #11
0
def main():
    """Build train/val/test splits of command/parse pairs and save them.

    Parses command-line arguments, collects pairs from the 2018 grammar
    (optionally adding groundings and paraphrasings), splits them either by
    command or by form, and writes ``train.txt``, ``val.txt``, ``test.txt``
    and ``meta.txt`` into ``data/<name>`` two levels above this script. With
    ``--incremental-datasets`` it also writes growing prefixes of the
    training data as ``train<count>.txt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-s","--split", default=[.7,.1,.2], nargs='+', type=float)
    parser.add_argument("-trc","--train-categories", default=[1, 2, 3], nargs='+', type=int)
    parser.add_argument("-tc","--test-categories", default=[1, 2, 3], nargs='+', type=int)
    parser.add_argument("-p", "--use-form-split", action='store_true', default=False)
    parser.add_argument("-g","--groundings", required=False, type=int, default=None)
    parser.add_argument("-a","--anonymized", required=False, default=True, action="store_true")
    parser.add_argument("-m", "--match-form-split", required=False, default=None, type=str)
    parser.add_argument("-na","--no-anonymized", required=False, dest="anonymized", action="store_false")
    parser.add_argument("-ra", "--run-anonymizer", required=False, default=False, action="store_true")
    parser.add_argument("-t", "--paraphrasings", required=False, default=None, type=str)
    parser.add_argument("--name", default=None, type=str)
    parser.add_argument("--seed", default=0, required=False, type=int)
    parser.add_argument("-i","--incremental-datasets", action='store_true', required=False)
    parser.add_argument("-f", "--force-overwrite", action="store_true", required=False, default=False)
    args = parser.parse_args()

    validate_args(args)

    cmd_gen = Generator(grammar_format_version=2018)
    random_source = random.Random(args.seed)

    different_test_dist = (args.test_categories != args.train_categories)

    pairs_out_path = os.path.join(os.path.abspath(os.path.dirname(__file__) + "/../.."), "data", args.name)
    train_out_path = os.path.join(pairs_out_path, "train.txt")
    val_out_path = os.path.join(pairs_out_path, "val.txt")
    test_out_path = os.path.join(pairs_out_path, "test.txt")
    meta_out_path = os.path.join(pairs_out_path, "meta.txt")

    if args.force_overwrite and os.path.isdir(pairs_out_path):
        shutil.rmtree(pairs_out_path)
    os.mkdir(pairs_out_path)

    grammar_dir = os.path.abspath(os.path.dirname(__file__) + "/../../resources/generator2018")

    generator = load_all_2018_by_cat(cmd_gen, grammar_dir)

    pairs = [{}, {}, {}]
    if args.anonymized:
        pairs = [pairs_without_placeholders(rules, semantics) for _, rules, _, semantics in generator]

    # For now this only works with all data
    if args.groundings and len(args.train_categories) == 3:
        for _ in range(args.groundings):
            groundings_by_cat = get_grounding_per_each_parse_by_cat(generator, random_source)
            # NOTE(review): every grounding is stored in the first category's
            # dict, not in the dict of the category it came from - confirm
            # this is intended. The zip only bounds the number of categories.
            for _, cat_groundings in zip(pairs, groundings_by_cat):
                for utt, form_anon, _ in cat_groundings:
                    pairs[0][tree_printer(utt)] = tree_printer(form_anon)

    if args.paraphrasings and len(args.train_categories) == 3:
        paraphrasing_pairs = load_data(args.paraphrasings, cmd_gen.lambda_parser)
        if args.run_anonymizer:
            paths = tuple(
                map(lambda x: join(grammar_dir, x), ["objects.xml", "locations.xml", "names.xml", "gestures.xml"]))
            entities = load_entities_from_xml(*paths)
            anonymizer = Anonymizer(*entities)
            anon_para_pairs = {}
            anon_trigerred = 0
            for command, form in paraphrasing_pairs.items():
                anonymized_command = anonymizer(command)
                if anonymized_command != command:
                    anon_trigerred += 1
                anon_para_pairs[anonymized_command] = form
            paraphrasing_pairs = anon_para_pairs
            print(anon_trigerred, len(paraphrasing_pairs))
        pairs[0] = merge_dicts(pairs[0], paraphrasing_pairs)

    by_command, by_form = determine_unique_cat_data(pairs)

    if args.use_form_split:
        data_to_split = by_form
    else:
        data_to_split = by_command
    train_pairs, test_pairs = get_pairs_by_cats(data_to_split, args.train_categories, args.test_categories)

    # Randomize for the split, but then sort by command length before we save out so that things are easier to read.
    # If these lists are the same, they need to be shuffled the same way...
    random.Random(args.seed).shuffle(train_pairs)
    random.Random(args.seed).shuffle(test_pairs)

    # Peg this split to match the split in another dataset. Helpful for making them mergeable while still preserving
    # the no-form-seen-before property of the form split
    if args.match_form_split:
        train_match = load_data(args.match_form_split + "/train.txt", cmd_gen.lambda_parser)
        train_match = set(train_match.values())
        val_match = load_data(args.match_form_split + "/val.txt", cmd_gen.lambda_parser)
        val_match = set(val_match.values())
        test_match = load_data(args.match_form_split + "/test.txt", cmd_gen.lambda_parser)
        test_match = set(test_match.values())
        total_match = len(train_match) + len(val_match) + len(test_match)
        train_percentage = len(train_match) / total_match
        val_percentage = len(val_match) / total_match
        test_percentage = len(test_match) / total_match
        train = []
        val = []
        test = []
        # TODO: Square this away with test dist param. Probably drop the cat params
        for form, commands in train_pairs:
            if form in train_match:
                target = train
            elif form in val_match:
                target = val
            elif form in test_match:
                target = test
            else:
                # Forms absent from the reference dataset are reported and dropped.
                print(form)
                continue
            target.append((form, commands))
    else:
        train_percentage, val_percentage, test_percentage = args.split
        if different_test_dist:
            # Just one split for the first dist, then use all of test
            split1 = int(train_percentage * len(train_pairs))
            train, val, test = train_pairs[:split1], train_pairs[split1:], test_pairs
        else:
            # If we're training and testing on the same distributions, these should match exactly
            assert train_pairs == test_pairs
            split1 = int(train_percentage * len(train_pairs))
            split2 = int((train_percentage + val_percentage) * len(train_pairs))
            train, val, test = train_pairs[:split1], train_pairs[split1:split2], train_pairs[split2:]

    # Parse splits would have stored parse-(command list) pairs, so lets
    # flatten out those lists if we need to.
    if args.use_form_split:
        train = flatten(train)
        val = flatten(val)
        test = flatten(test)

    # With this switch, we'll simulate getting data one batch at a time
    # so we can assess how quickly we improve
    if args.incremental_datasets:
        # Strip only the final extension; joining the dot-split pieces would
        # also delete any dots that occur earlier in the path.
        train_out_stem = os.path.splitext(train_out_path)[0]
        limit = 16
        count = 1
        while limit < len(train):
            data_to_write = sorted(train[:limit], key=lambda x: len(x[0]))
            with open(train_out_stem + str(count) + ".txt", "w") as f:
                for sentence, parse in data_to_write:
                    f.write(sentence + '\n' + str(parse) + '\n')
            limit += 16
            count += 1

    save_data(train, train_out_path)
    save_data(val, val_out_path)
    save_data(test, test_out_path)

    # Whitespace-delimited token frequencies over all three splits.
    command_vocab = Counter()
    parse_vocab = Counter()
    for command, parse in itertools.chain(train, val, test):
        for token in command.split():
            command_vocab[token] += 1
        for token in parse.split():
            parse_vocab[token] += 1

    info = "Generated {} dataset with {:.2f}/{:.2f}/{:.2f} split\n".format(args.name, train_percentage, val_percentage, test_percentage)
    total_train_set = len(train) + len(val)
    if different_test_dist:
        total_test_set = len(test)
    else:
        total_train_set += len(test)
        total_test_set = total_train_set
    info += "Exact split percentage: {:.2f}/{:.2f}/{:.2f} split\n".format(len(train)/total_train_set, len(val)/total_train_set, len(test)/total_test_set)

    info += "train={} val={} test={}".format(len(train), len(val), len(test))
    print(info)
    with open(meta_out_path, "w") as f:
        f.write(info)

        f.write("\n\nUtterance vocab\n")
        for token, count in sorted(command_vocab.items(), key=operator.itemgetter(1), reverse=True):
            f.write("{} {}\n".format(token, str(count)))

        f.write("\n\nParse vocab\n")
        for token, count in sorted(parse_vocab.items(), key=operator.itemgetter(1), reverse=True):
            f.write("{} {}\n".format(token, str(count)))
コード例 #12
0
import os
from os.path import join

from gpsr_command_understanding.generator import Generator
from gpsr_command_understanding.grammar import tree_printer
from gpsr_command_understanding.loading_helpers import load_all_2018_by_cat
from gpsr_command_understanding.tokens import ROOT_SYMBOL
from gpsr_command_understanding.generation import generate_random_pair

import random

# Locate the 2018 generator resources relative to this script.
_resources_suffix = "/../../resources/generator2018/"
grammar_dir = os.path.abspath(os.path.dirname(__file__) + _resources_suffix)
common_path = join(grammar_dir, "common_rules.txt")

# Entity definition files that live alongside the grammar.
paths = tuple(
    join(grammar_dir, file_name)
    for file_name in
    ["objects.xml", "locations.xml", "names.xml", "gestures.xml"])

generator = Generator()
rules = load_all_2018_by_cat(generator, grammar_dir)

# Draw one random pair from the first loaded entry, using items 1 and 3 of
# that entry, and print the utterance.
first_entry = rules[0]
utterance, parse = generate_random_pair(
    ROOT_SYMBOL,
    first_entry[1],
    first_entry[3],
    random_generator=random.Random())
print(tree_printer(utterance))