Code Example #1
    def test_load_2018(self):
        generator = Generator(grammar_format_version=2018)
        all_2018 = load_all_2018_by_cat(generator,
                                        GRAMMAR_DIR_2018,
                                        expand_shorthand=False)
        # To manually inspect correctness for now...
        """for nonterm, rules in all_2018[0].items():
Code Example #2
    def test_generate(self):
        generator = Generator(grammar_format_version=2018)
        grammar = generator.load_rules(os.path.join(FIXTURE_DIR,
                                                    "grammar.txt"))
        semantics = generator.load_semantics_rules(
            os.path.join(FIXTURE_DIR, "semantics.txt"))
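        # Exhaustively expand from the Main nonterminal; the fixture grammar yields six sentence/parse pairs.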
        pairs = list(
            generate_sentence_parse_pairs(NonTerminal("Main"), grammar,
                                          semantics))
        self.assertEqual(len(pairs), 6)
Code Example #3
    def test_parse_utterance(self):
        rules = {}
        generator = Generator(grammar_format_version=2019)
        grammar = generator.load_rules(os.path.join(FIXTURE_DIR,
                                                    "grammar.txt"),
                                       expand_shorthand=False)
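        # With shorthand left unexpanded, wildcard tokens such as {pron} remain literal symbols in the grammar.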
        parser = GrammarBasedParser(grammar)
        test = parser("say hi to him right now please")
        print(test.pretty())
        test = parser("bring it to {pron} now")
        print(test.pretty())
Code Example #4
    def test_nearest_neighbor_parser(self):
        generator = Generator(grammar_format_version=2018)
        rules = load_all_2019(generator, GRAMMAR_DIR)

        sentences = generate_sentences(ROOT_SYMBOL, rules[0])
        parser = GrammarBasedParser(rules[0])
        sentences = list(set([tree_printer(x) for x in sentences]))
        neighbors = [(sentence, parser(sentence)) for sentence in sentences]
        nearest_neighbor_parser = KNearestNeighborParser(neighbors)
        some_sentence = sentences[0]
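        # Dropping the final character gives a near-miss string that should still map to the same parse via its nearest neighbor.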
        tweaked = some_sentence[:-1]
        expected_parse = parser(some_sentence)
        self.assertEqual(nearest_neighbor_parser(some_sentence),
                         expected_parse)
        self.assertEqual(nearest_neighbor_parser(tweaked), expected_parse)
Code Example #5
def main():
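    # seed, groundings_per_parse, and rehprasings_per_hit are assumed to be module-level settings defined elsewhere in this script.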
    random_source = random.Random(seed)
    grammar_dir = os.path.abspath(
        os.path.dirname(__file__) + "/../../resources/generator2018")
    out_file_path = os.path.abspath(
        os.path.dirname(__file__) +
        "/../../data/rephrasings_data_{}_{}.csv".format(
            seed, groundings_per_parse))
    cmd_gen = Generator(grammar_format_version=2018)
    generator = load_all_2018_by_cat(cmd_gen, grammar_dir)

    all_examples = []
    for i in range(groundings_per_parse):
        grounded_examples = get_grounding_per_each_parse(
            generator, random_source)
        random_source.shuffle(grounded_examples)
        all_examples += grounded_examples

    with open(out_file_path, 'w') as csvfile:
        output = csv.writer(csvfile, quoting=csv.QUOTE_MINIMAL)
        command_columns = [("command" + str(x), "parse" + str(x),
                            "parse_ground" + str(x))
                           for x in range(1, rehprasings_per_hit + 1)]
        # Flatten the per-command column triples into a single header row
        command_columns = [x for cols in command_columns for x in cols]
        output.writerow(command_columns)

        chunks = list(chunker(all_examples, rehprasings_per_hit))
        print("Writing {} HITS".format(len(chunks)))
        for i, chunk in enumerate(chunks):
            if len(chunk) < rehprasings_per_hit:
                needed = rehprasings_per_hit - len(chunk)
                # Sample from previous hits to fill out this last one
                chunk += random_source.sample(
                    [pair for prev_chunk in chunks[:i] for pair in prev_chunk],
                    k=needed)
            line = []
            for utterance, parse_anon, parse_ground in chunk:
                line += [
                    tree_printer(utterance),
                    tree_printer(parse_anon),
                    tree_printer(parse_ground)
                ]
            output.writerow(line)

    # Let's verify that we can load the output back in...
    with open(out_file_path, 'r') as csvfile:
        reader = csv.DictReader(csvfile)
        for _ in reader:
            pass
Code Example #6
    def test_parse_all_of_2019(self):
        generator = Generator(grammar_format_version=2018)

        grammar_dir = os.path.abspath(
            os.path.dirname(__file__) + "/../resources/generator2019")
        rules, rules_anon, _, _, _ = load_all_2019(generator, grammar_dir)

        sentences = generate_sentences(ROOT_SYMBOL, rules)
        parser = GrammarBasedParser(rules)
        sentences = set([tree_printer(x) for x in sentences])
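        # Every sentence generated from the grammar should be accepted by a parser built from the same rules.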
        succeeded = 0
        for sentence in sentences:
            parsed = parser(sentence)
            if parsed:
                succeeded += 1

        self.assertEqual(len(sentences), succeeded)
Code Example #7
    def test_parse_all_2019_anonymized(self):
        generator = Generator(grammar_format_version=2019)

        grammar_dir = os.path.abspath(
            os.path.dirname(__file__) + "/../resources/generator2019")
        rules, rules_anon, rules_ground, semantics, entities = load_all_2019(
            generator, grammar_dir)

        sentences = generate_sentence_parse_pairs(
            ROOT_SYMBOL,
            rules_ground, {},
            yield_requires_semantics=False,
            random_generator=random.Random(1))
        parser = GrammarBasedParser(rules_anon)

        # Bring me the apple from the fridge to the kitchen
        # ---straight anon to clusters--->
        # Bring me the {ob}  from the {loc} to the {loc}
        # ---Grammar based parser--->
        # (Failure; grammar has numbers on locs)

        # Bring me the apple from the fridge to the kitchen
        # ---id naive number anon--->
        # Bring me the {ob}  from the {loc 1} to the {loc 2}
        # ---Grammar based parser--->
        # (Failure; numbers may be wrong, or only coincidentally correct)

        anonymizer = Anonymizer(*entities)
        parser = AnonymizingParser(parser, anonymizer)
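        # Anonymize each input sentence before handing it to the parser built over the anonymized rules.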
        num_tested = 1000
        succeeded = 0
        for sentence, parse in itertools.islice(sentences, num_tested):
            sentence = tree_printer(sentence)
            parsed = parser(sentence)
            if parsed:
                succeeded += 1
            else:
                print(sentence)
                print(anonymizer(sentence))
                print()
                print(parser(anonymizer(sentence)))

        self.assertEqual(succeeded, num_tested)
Code Example #8
def main():
    assert len(sys.argv) == 4
    reader = Seq2SeqDatasetReader(source_tokenizer=NoOpTokenizer(),
                                  target_tokenizer=NoOpTokenizer())
    train = reader.read(sys.argv[1])
    val = reader.read(sys.argv[2])
    test = reader.read(sys.argv[3])

    generator = Generator()
    rules, rules_anon, rules_ground, semantics, entities = load_all_2018(
        generator, GRAMMAR_DIR)
    anonymizer = Anonymizer(*entities)

    neighbors = []
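    # Build the retrieval set: anonymized train and val commands paired with their target forms.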
    for x in itertools.chain(train, val):
        command = str(x["source_tokens"][1:-1][0])
        form = str(x["target_tokens"][1:-1][0])
        anon_command = anonymizer(command)
        neighbors.append((anon_command, form))

    test_pairs = []
    for x in test:
        test_pairs.append((str(x["source_tokens"][1:-1][0]),
                           str(x["target_tokens"][1:-1][0])))

    print("Check grammar membership")
    naive_parser = GrammarBasedParser(rules_anon)
    anon_parser = AnonymizingParser(naive_parser, anonymizer)

    correct, parsed = bench_parser(anon_parser, test_pairs)
    print("Got {} of {} ({:.2f})".format(parsed, len(test_pairs),
                                         100.0 * parsed / len(test_pairs)))

    print("Jaccard distance")
    sweep_thresh(neighbors, test_pairs, anonymizer,
                 lambda x, y: jaccard_distance(set(x.split()), set(y.split())),
                 [0.1 * i for i in range(11)])
    print("Edit distance")
    sweep_thresh(neighbors, test_pairs, anonymizer, editdistance.eval)
Code Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-s",
                        "--split",
                        default=[.7, .1, .2],
                        nargs='+',
                        type=float)
    parser.add_argument("-trc",
                        "--train-categories",
                        default=[1, 2, 3],
                        nargs='+',
                        type=int)
    parser.add_argument("-tc",
                        "--test-categories",
                        default=[1, 2, 3],
                        nargs='+',
                        type=int)
    parser.add_argument("-p",
                        "--use-parse-split",
                        action='store_true',
                        default=False)
    parser.add_argument("-b", "--branch-cap", default=None, type=int)
    parser.add_argument("-t", "--turk", required=False, default=None, type=str)
    parser.add_argument("--name", default=None, type=str)
    parser.add_argument("--seed", default=0, required=False, type=int)
    parser.add_argument("-f",
                        "--force-overwrite",
                        action="store_true",
                        required=False,
                        default=False)
    args = parser.parse_args()

    validate_args(args)

    cmd_gen = Generator(grammar_format_version=2018,
                        semantic_form_version="slot")
    #cmd_gen = Generator(grammar_format_version=2018, semantic_form_version="lambda")
    random_source = random.Random(args.seed)

    pairs_out_path = os.path.join(
        os.path.abspath(os.path.dirname(__file__) + "/../.."), "data",
        args.name)
    train_out_path = os.path.join(pairs_out_path, "train.txt")
    val_out_path = os.path.join(pairs_out_path, "val.txt")
    test_out_path = os.path.join(pairs_out_path, "test.txt")
    meta_out_path = os.path.join(pairs_out_path, "meta.txt")

    if args.force_overwrite and os.path.isdir(pairs_out_path):
        shutil.rmtree(pairs_out_path)
    os.mkdir(pairs_out_path)

    grammar_dir = os.path.abspath(
        os.path.dirname(__file__) + "/../../resources/generator2018")
    common_path = join(grammar_dir, "common_rules.txt")
    paths = tuple(
        map(lambda x: join(grammar_dir, x),
            ["objects.xml", "locations.xml", "names.xml", "gestures.xml"]))
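    # Load each category's grammar together with the shared common rules and its slot semantics.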
    grammar_file_paths = [
        common_path,
        join(grammar_dir, "gpsr_category_1_grammar.txt")
    ]
    semantics_file_paths = [
        join(grammar_dir, "gpsr_category_1_slot.txt"),
        join(grammar_dir, "common_rules_slot.txt")
    ]

    cmd_gen.load_set_of_rules(grammar_file_paths, semantics_file_paths, *paths)

    grammar_file_paths = [
        common_path,
        join(grammar_dir, "gpsr_category_2_grammar.txt")
    ]
    semantics_file_paths = [
        join(grammar_dir, "gpsr_category_2_slot.txt"),
        join(grammar_dir, "common_rules_slot.txt")
    ]

    cmd_gen.load_set_of_rules(grammar_file_paths, semantics_file_paths, *paths)

    grammar_file_paths = [
        common_path,
        join(grammar_dir, "gpsr_category_3_grammar.txt")
    ]
    semantics_file_paths = [
        join(grammar_dir, "gpsr_category_3_slot.txt"),
        join(grammar_dir, "common_rules_slot.txt")
    ]

    cmd_gen.load_set_of_rules(grammar_file_paths, semantics_file_paths, *paths)

    #generator = cmd_gen.rules[2]
    #for k,v in generator[2].items():
    #    print(k)
    #    print(v)
    #    print("-------------------------------------------------------------------")

    pairs = []
    pairs.append(
        cmd_gen.get_utterance_semantics_pairs(random_source, [1],
                                              args.branch_cap))
    pairs.append(
        cmd_gen.get_utterance_semantics_pairs(random_source, [2],
                                              args.branch_cap))
    pairs.append(
        cmd_gen.get_utterance_semantics_pairs(random_source, [3],
                                              args.branch_cap))
    #pairs = [cmd_gen.get_utterance_semantics_pairs(random_source, [cat], args.branch_cap) for cat in [1, 2, 3]]

    #if args.turk and len(args.train_categories) == 3:
    #    turk_pairs = load_turk_data(args.turk, cmd_gen.lambda_parser)
    #    pairs[0] = merge_dicts(pairs[0], turk_pairs)

    by_utterance, by_parse = determine_unique_cat_data(pairs)

    if args.use_parse_split:
        data_to_split = by_parse
    else:
        data_to_split = by_utterance
    train_pairs, test_pairs = get_pairs_by_cats(data_to_split,
                                                args.train_categories,
                                                args.test_categories)

    # Randomize for the split, but then sort by utterance length before we save out so that things are easier to read.
    # If these lists are the same, they need to be shuffled the same way...
    random.Random(args.seed).shuffle(train_pairs)
    random.Random(args.seed).shuffle(test_pairs)

    if args.test_categories == args.train_categories:
        # If we're training and testing on the same distributions, these should match exactly
        assert train_pairs == test_pairs

    different_test_dist = args.test_categories != args.train_categories

    train_percentage, val_percentage, test_percentage = args.split
    if different_test_dist:
        # Just one split for the first dist, then use all of test
        split1 = int(train_percentage * len(train_pairs))
        train, val, test = train_pairs[:split1], train_pairs[
            split1:], test_pairs
    else:
        split1 = int(train_percentage * len(train_pairs))
        split2 = int((train_percentage + val_percentage) * len(train_pairs))
        train, val, test = train_pairs[:split1], train_pairs[
            split1:split2], train_pairs[split2:]

    # Parse splits would have stored parse-(utterance list) pairs, so let's
    # flatten out those lists if we need to.
    if args.use_parse_split:
        train = flatten(train)
        val = flatten(val)
        test = flatten(test)

    save_slot_data(train, train_out_path)
    save_slot_data(val, val_out_path)
    save_slot_data(test, test_out_path)

    utterance_vocab = Counter()
    parse_vocab = Counter()
    for utterance, parse in itertools.chain(train, val, test):
        for token in utterance.split(" "):
            utterance_vocab[token] += 1
        for token in parse.split(" "):
            parse_vocab[token] += 1

    info = "Generated {} dataset with {:.2f}/{:.2f}/{:.2f} split\n".format(
        args.name, train_percentage, val_percentage, test_percentage)
    total = len(train) + len(val) + len(test)
    info += "Exact split percentage: {:.2f}/{:.2f}/{:.2f} split\n".format(
        len(train) / total,
        len(val) / total,
        len(test) / total)

    info += "train={} val={} test={}".format(len(train), len(val), len(test))
    print(info)
    with open(meta_out_path, "w") as f:
        f.write(info)

        f.write("\n\nUtterance vocab\n")
        for token, count in sorted(utterance_vocab.items(),
                                   key=operator.itemgetter(1),
                                   reverse=True):
            f.write("{} {}\n".format(token, str(count)))

        f.write("\n\nParse vocab\n")
        for token, count in sorted(parse_vocab.items(),
                                   key=operator.itemgetter(1),
                                   reverse=True):
            f.write("{} {}\n".format(token, str(count)))
Code Example #10
def main():
    out_root = os.path.abspath(os.path.dirname(__file__) + "/../../data")
    grammar_dir = os.path.abspath(
        os.path.dirname(__file__) + "/../../resources/generator2018")

    cmd_gen = Generator(grammar_format_version=2018)
    generator = load_all_2018_by_cat(cmd_gen, grammar_dir)

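    # Exhaustively generate each category's sentences and expand all annotated sentence/parse pairs.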
    cat_sentences = [
        set(generate_sentences(ROOT_SYMBOL, rules))
        for _, rules, _, _ in generator
    ]
    pairs = [{
        utterance: parse
        for utterance, parse in expand_all_semantics(rules, semantics)
    } for _, rules, _, semantics in generator]

    pairs = [
        pairs_without_placeholders(rules, semantics)
        for _, rules, _, semantics in generator
    ]
    by_utterance, by_parse = determine_unique_cat_data(
        pairs, keep_new_utterance_repeat_parse_for_lower_cat=False)
    unique = []
    for i, _ in enumerate(cat_sentences):
        prev_cats = cat_sentences[:i]
        if prev_cats:
            # Don't count the sentence as unique unless it hasn't happened in any earlier categories
            prev_cat_sentences = set().union(*prev_cats)
            overlapped_with_prev = cat_sentences[i].intersection(
                prev_cat_sentences)
            unique.append(cat_sentences[i].difference(prev_cat_sentences))
        else:
            unique.append(cat_sentences[i])
    # Sets should be disjoint
    assert (len(set().union(*unique)) == sum([len(cat) for cat in unique]))

    all_sentences = [tree_printer(x) for x in set().union(*cat_sentences)]
    all_pairs = pairs

    annotated = [get_annotated_sentences(x) for x in zip(cat_sentences, pairs)]

    unique_annotated = [
        get_annotated_sentences((unique_sen, cat_pairs))
        for unique_sen, cat_pairs in zip(unique, pairs)
    ]
    unique_sentence_parses = [[
        cat_pairs[ann_sen] for ann_sen in cat_annotated
    ] for cat_annotated, cat_pairs in zip(unique_annotated, pairs)]
    unique_sentence_parses = [set(x) for x in unique_sentence_parses]

    combined_annotations = set().union(*annotated)
    combined_annotations.intersection_update(all_sentences)

    parseless = [
        sen.difference(annotated_sentences)
        for sen, annotated_sentences in zip(cat_sentences, annotated)
    ]

    out_paths = [
        join(out_root,
             str(i) + "_sentences.txt") for i in range(1, 4)
    ]

    for cat_out_path, sentences in zip(out_paths, cat_sentences):
        with open(cat_out_path, "w") as f:
            for sentence in sentences:
                assert not has_placeholders(sentence)
                f.write(tree_printer(sentence) + '\n')

    out_paths = [join(out_root, str(i) + "_pairs.txt") for i in range(1, 4)]

    for cat_out_path, pairs in zip(out_paths, all_pairs):
        with open(cat_out_path, "w") as f:
            for sentence, parse in pairs.items():
                f.write(sentence + '\n' + parse + '\n')

    meta_out_path = join(out_root, "annotations_meta.txt")
    with open(meta_out_path, "w") as f:
        f.write("Coverage:\n")
        cat_parse_lengths = []
        cat_filtered_parse_lengths = []
        cat_sen_lengths = []
        for i, (annotated_sen, sen, unique_parses) in enumerate(
                zip(annotated, cat_sentences, unique_sentence_parses)):
            f.write("cat{0} {1}/{2} {3:.1f}%\n".format(
                i + 1, len(annotated_sen), len(sen),
                100.0 * len(annotated_sen) / len(sen)))
            f.write("\t unique parses: {}\n".format(len(unique_parses)))
            cat_sen_lengths.append(
                [len(tree_printer(sentence).split()) for sentence in sen])
            avg_sentence_length = np.mean(cat_sen_lengths[i])
            parse_lengths = []
            filtered_parse_lengths = []
            for parse in unique_parses:
                parse_lengths.append(len(parse.split()))
                stop_tokens_removed = re.sub(r'( e |"|\)|\()', "", parse)
                filtered_parse_lengths.append(len(stop_tokens_removed.split()))
            cat_parse_lengths.append(parse_lengths)
            cat_filtered_parse_lengths.append(filtered_parse_lengths)
            avg_parse_length = np.mean(cat_parse_lengths[i])
            avg_filtered_parse_length = np.mean(cat_filtered_parse_lengths[i])
            f.write(
                "\t avg sentence length (tokens): {:.1f} avg parse length (tokens): {:.1f} avg filtered parse length (tokens): {:.1f}\n"
                .format(avg_sentence_length, avg_parse_length,
                        avg_filtered_parse_length))

        f.write("combined {0}/{1} {2:.1f}%\n".format(
            len(combined_annotations), len(all_sentences),
            100.0 * len(combined_annotations) / len(all_sentences)))
        f.write("combined unique parses: {}\n".format(
            len(set().union(*unique_sentence_parses))))

        all_sen_lengths = [length for cat in cat_sen_lengths for length in cat]
        all_parse_lengths = [
            length for cat in cat_parse_lengths for length in cat
        ]
        all_filtered_parse_lengths = [
            length for cat in cat_filtered_parse_lengths for length in cat
        ]
        f.write(
            "combined avg sentence length (tokens): {:.1f} avg parse length (tokens): {:.1f} avg filtered parse length (tokens): {:.1f}\n"
            .format(np.mean(all_sen_lengths), np.mean(all_parse_lengths),
                    np.mean(all_filtered_parse_lengths)))
    print("No parses for:")
    for cat in parseless:
        for sentence in sorted(map(tree_printer, cat)):
            print(sentence)
        print("-----------------")
Code Example #11
    def setUp(self) -> None:
        self.generator = Generator(grammar_format_version=2019)
Code Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-s","--split", default=[.7,.1,.2], nargs='+', type=float)
    parser.add_argument("-trc","--train-categories", default=[1, 2, 3], nargs='+', type=int)
    parser.add_argument("-tc","--test-categories", default=[1, 2, 3], nargs='+', type=int)
    parser.add_argument("-p", "--use-form-split", action='store_true', default=False)
    parser.add_argument("-g","--groundings", required=False, type=int, default=None)
    parser.add_argument("-a","--anonymized", required=False, default=True, action="store_true")
    parser.add_argument("-m", "--match-form-split", required=False, default=None, type=str)
    parser.add_argument("-na","--no-anonymized", required=False, dest="anonymized", action="store_false")
    parser.add_argument("-ra", "--run-anonymizer", required=False, default=False, action="store_true")
    parser.add_argument("-t", "--paraphrasings", required=False, default=None, type=str)
    parser.add_argument("--name", default=None, type=str)
    parser.add_argument("--seed", default=0, required=False, type=int)
    parser.add_argument("-i","--incremental-datasets", action='store_true', required=False)
    parser.add_argument("-f", "--force-overwrite", action="store_true", required=False, default=False)
    args = parser.parse_args()

    validate_args(args)

    cmd_gen = Generator(grammar_format_version=2018)
    random_source = random.Random(args.seed)

    different_test_dist = (args.test_categories != args.train_categories)

    pairs_out_path = os.path.join(os.path.abspath(os.path.dirname(__file__) + "/../.."), "data", args.name)
    train_out_path = os.path.join(pairs_out_path, "train.txt")
    val_out_path = os.path.join(pairs_out_path, "val.txt")
    test_out_path = os.path.join(pairs_out_path, "test.txt")
    meta_out_path = os.path.join(pairs_out_path, "meta.txt")

    if args.force_overwrite and os.path.isdir(pairs_out_path):
        shutil.rmtree(pairs_out_path)
    os.mkdir(pairs_out_path)
    
    grammar_dir = os.path.abspath(os.path.dirname(__file__) + "/../../resources/generator2018")

    generator = load_all_2018_by_cat(cmd_gen, grammar_dir)

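    # One command-to-form dict per category; filled from the grammar below when anonymized forms are requested.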
    pairs = [{}, {}, {}]
    if args.anonymized:
        pairs = [pairs_without_placeholders(rules, semantics) for _, rules, _, semantics in generator]

    # For now this only works with all data
    if args.groundings and len(args.train_categories) == 3:
        for i in range(args.groundings):
            groundings = get_grounding_per_each_parse_by_cat(generator,random_source)
            for cat_pairs, cat_groundings in zip(pairs, groundings):
                for utt, form_anon, _ in cat_groundings:
                    pairs[0][tree_printer(utt)] = tree_printer(form_anon)

    if args.paraphrasings and len(args.train_categories) == 3:
        paraphrasing_pairs = load_data(args.paraphrasings, cmd_gen.lambda_parser)
        if args.run_anonymizer:
            paths = tuple(
                map(lambda x: join(grammar_dir, x), ["objects.xml", "locations.xml", "names.xml", "gestures.xml"]))
            entities = load_entities_from_xml(*paths)
            anonymizer = Anonymizer(*entities)
            anon_para_pairs = {}
            anon_triggered = 0
            for command, form in paraphrasing_pairs.items():
                anonymized_command = anonymizer(command)
                if anonymized_command != command:
                    anon_triggered += 1
                anon_para_pairs[anonymized_command] = form
            paraphrasing_pairs = anon_para_pairs
            print(anon_triggered, len(paraphrasing_pairs))
        pairs[0] = merge_dicts(pairs[0], paraphrasing_pairs)

    #pairs_in = [pairs_without_placeholders(rules, semantics, only_in_grammar=True) for _, rules, _, semantics in generator]
    by_command, by_form = determine_unique_cat_data(pairs)

    if args.use_form_split:
        data_to_split = by_form
    else:
        data_to_split = by_command
    train_pairs, test_pairs = get_pairs_by_cats(data_to_split, args.train_categories, args.test_categories)

    # Randomize for the split, but then sort by command length before we save out so that things are easier to read.
    # If these lists are the same, they need to be shuffled the same way...
    random.Random(args.seed).shuffle(train_pairs)
    random.Random(args.seed).shuffle(test_pairs)

    # Peg this split to match the split in another dataset. Helpful for making them mergeable while still preserving
    # the no-form-seen-before property of the form split
    if args.match_form_split:
        train_match = load_data(args.match_form_split + "/train.txt", cmd_gen.lambda_parser)
        train_match = set(train_match.values())
        val_match = load_data(args.match_form_split + "/val.txt", cmd_gen.lambda_parser)
        val_match = set(val_match.values())
        test_match = load_data(args.match_form_split + "/test.txt", cmd_gen.lambda_parser)
        test_match = set(test_match.values())
        total_match = len(train_match) + len(val_match) + len(test_match)
        train_percentage = len(train_match) / total_match
        val_percentage = len(val_match) / total_match
        test_percentage = len(test_match) / total_match
        train = []
        val = []
        test = []
        # TODO: Square this away with test dist param. Probably drop the cat params
        for form, commands in itertools.chain(train_pairs):
            target = None
            if form in train_match:
                target = train
            elif form in val_match:
                target = val
            elif form in test_match:
                target = test
            else:
                print(form)
                continue
                # assert False
            target.append((form, commands))
    else:
        train_percentage, val_percentage, test_percentage = args.split
        if different_test_dist:
            # Just one split for the first dist, then use all of test
            split1 = int(train_percentage * len(train_pairs))
            train, val, test = train_pairs[:split1], train_pairs[split1:], test_pairs
        else:
            # If we're training and testing on the same distributions, these should match exactly
            assert train_pairs == test_pairs
            split1 = int(train_percentage * len(train_pairs))
            split2 = int((train_percentage + val_percentage) * len(train_pairs))
            train, val, test = train_pairs[:split1], train_pairs[split1:split2], train_pairs[split2:]

    # Form splits would have stored form-(command list) pairs, so let's
    # flatten out those lists if we need to.
    if args.use_form_split:
        train = flatten(train)
        val = flatten(val)
        test = flatten(test)

    # With this switch, we'll simulate getting data one batch at a time
    # so we can assess how quickly we improve
    if args.incremental_datasets:
        limit = 16
        count = 1
        while limit < len(train):
            data_to_write = train[:limit]
            data_to_write = sorted(data_to_write, key=lambda x: len(x[0]))
            with open("".join(train_out_path.split(".")[:-1]) + str(count) + ".txt", "w") as f:
                for sentence, parse in data_to_write:
                    f.write(sentence + '\n' + str(parse) + '\n')
            limit += 16
            count += 1

    save_data(train, train_out_path)
    save_data(val, val_out_path)
    save_data(test, test_out_path)

    command_vocab = Counter()
    parse_vocab = Counter()
    for command, parse in itertools.chain(train, val, test):
        for token in command.split():
            command_vocab[token] += 1
        for token in parse.split():
            parse_vocab[token] += 1

    info = "Generated {} dataset with {:.2f}/{:.2f}/{:.2f} split\n".format(args.name, train_percentage, val_percentage, test_percentage)
    total_train_set = len(train) + len(val)
    if different_test_dist:
        total_test_set = len(test)
    else:
        total_train_set += len(test)
        total_test_set = total_train_set
    info += "Exact split percentage: {:.2f}/{:.2f}/{:.2f} split\n".format(len(train)/total_train_set, len(val)/total_train_set, len(test)/total_test_set)

    info += "train={} val={} test={}".format(len(train), len(val), len(test))
    print(info)
    with open(meta_out_path, "w") as f:
        f.write(info)

        f.write("\n\nUtterance vocab\n")
        for token, count in sorted(command_vocab.items(), key=operator.itemgetter(1), reverse=True):
            f.write("{} {}\n".format(token, str(count)))

        f.write("\n\nParse vocab\n")
        for token, count in sorted(parse_vocab.items(), key=operator.itemgetter(1), reverse=True):
            f.write("{} {}\n".format(token, str(count)))
Code Example #13
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 attention_function: SimilarityFunction = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.5) -> None:
        super(Seq2Seq, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

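        # Besides BLEU, track exact-match token sequence accuracy and whether each prediction parses under the generator's lambda parser.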
        self._token_based_metrics = [TokenSequenceAccuracy(), ParseValidity(Generator().lambda_parser)]

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder
        self._emb_dropout = Dropout(p=emb_dropout)

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError("You can only specify an attention module or an "
                                         "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
Code Example #14
import os
from os.path import join

from gpsr_command_understanding.generator import Generator
from gpsr_command_understanding.grammar import tree_printer
from gpsr_command_understanding.loading_helpers import load_all_2018_by_cat
from gpsr_command_understanding.tokens import ROOT_SYMBOL
from gpsr_command_understanding.generation import generate_random_pair

import random

grammar_dir = os.path.abspath(
    os.path.dirname(__file__) + "/../../resources/generator2018/")
common_path = join(grammar_dir, "common_rules.txt")

paths = tuple(
    map(lambda x: join(grammar_dir, x),
        ["objects.xml", "locations.xml", "names.xml", "gestures.xml"]))

generator = Generator()
rules = load_all_2018_by_cat(generator, grammar_dir)

utterance, parse = generate_random_pair(ROOT_SYMBOL,
                                        rules[0][1],
                                        rules[0][3],
                                        random_generator=random.Random())
print(tree_printer(utterance))
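# The paired semantic form can be printed the same way
print(tree_printer(parse))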